Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

horizon: Merge horizon-db-optimizations into master #4400

Merged
merged 7 commits into from
May 24, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions services/horizon/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -3,6 +3,13 @@
All notable changes to this project will be documented in this
file. This project adheres to [Semantic Versioning](http://semver.org/).

## Unreleased

- Querying claimable balances has been optimized ([4385](https://github.com/stellar/go/pull/4385)).
- Querying trade aggregations has been optimized ([4389](https://github.com/stellar/go/pull/4389)).
- Postgres connections for non ingesting Horizon instances are now configured to timeout on long running queries / transactions ([4390](https://github.com/stellar/go/pull/4390)).
- Added `disable-path-finding` Horizon flag to disable the path finding endpoints. This flag should be enabled on ingesting Horizon instances which do not serve HTTP traffic ([4399](https://github.com/stellar/go/pull/4399)).

## V2.17.0

This is the final release after the [release candidate](v2.17.0-release-candidate), including some small additional changes:
4 changes: 3 additions & 1 deletion services/horizon/internal/app.go
Original file line number Diff line number Diff line change
@@ -90,7 +90,9 @@ func (a *App) Serve() error {
}

go a.run()
go a.orderBookStream.Run(a.ctx)
if !a.config.DisablePathFinding {
go a.orderBookStream.Run(a.ctx)
}

// WaitGroup for all go routines. Makes sure that DB is closed when
// all services gracefully shutdown.
2 changes: 2 additions & 0 deletions services/horizon/internal/config.go
Original file line number Diff line number Diff line change
@@ -53,6 +53,8 @@ type Config struct {
// DisablePoolPathFinding configures horizon to run path finding without including liquidity pools
// in the path finding search.
DisablePoolPathFinding bool
// DisablePathFinding configures horizon without the path finding endpoint.
DisablePathFinding bool
// MaxPathFindingRequests is the maximum number of path finding requests horizon will allow
// in a 1-second period. A value of 0 disables the limit.
MaxPathFindingRequests uint
Original file line number Diff line number Diff line change
@@ -220,7 +220,7 @@ func (q *Q) GetClaimableBalances(ctx context.Context, query ClaimableBalancesQue
sql = sql.
Prefix("WITH cb AS (").
Suffix(
") select "+claimableBalancesSelectStatement+" from cb LIMIT ?",
"LIMIT ?) select "+claimableBalancesSelectStatement+" from cb",
query.PageQuery.Limit,
)

41 changes: 34 additions & 7 deletions services/horizon/internal/db2/history/offers.go
Original file line number Diff line number Diff line change
@@ -2,13 +2,16 @@ package history

import (
"context"
"database/sql"

sq "github.com/Masterminds/squirrel"
"github.com/jmoiron/sqlx"

"github.com/stellar/go/support/errors"
)

const offersBatchSize = 50000

// QOffers defines offer related queries.
type QOffers interface {
StreamAllOffers(ctx context.Context, callback func(Offer) error) error
@@ -83,28 +86,52 @@ func (q *Q) GetOffers(ctx context.Context, query OffersQuery) ([]Offer, error) {

// StreamAllOffers loads all non deleted offers
func (q *Q) StreamAllOffers(ctx context.Context, callback func(Offer) error) error {
if tx := q.GetTx(); tx == nil {
return errors.New("cannot be called outside of a transaction")
}
if opts := q.GetTxOptions(); opts == nil || !opts.ReadOnly || opts.Isolation != sql.LevelRepeatableRead {
return errors.New("should only be called in a repeatable read transaction")
}

lastID := int64(0)
for {
nextID, err := q.streamAllOffersBatch(ctx, lastID, offersBatchSize, callback)
if err != nil {
return err
}
if lastID == nextID {
return nil
}
lastID = nextID
}
}

func (q *Q) streamAllOffersBatch(ctx context.Context, lastId int64, limit uint64, callback func(Offer) error) (int64, error) {
var rows *sqlx.Rows
var err error

if rows, err = q.Query(ctx, selectOffers.Where("deleted = ?", false)); err != nil {
return errors.Wrap(err, "could not run all offers select query")
rows, err = q.Query(ctx, selectOffers.
Where("deleted = ?", false).
Where("offer_id > ? ", lastId).
OrderBy("offer_id asc").Limit(limit))
if err != nil {
return 0, errors.Wrap(err, "could not run all offers select query")
}

defer rows.Close()

for rows.Next() {
offer := Offer{}
if err = rows.StructScan(&offer); err != nil {
return errors.Wrap(err, "could not scan row into offer struct")
return 0, errors.Wrap(err, "could not scan row into offer struct")
}

if err = callback(offer); err != nil {
return err
return 0, err
}
lastId = offer.OfferID
}

return rows.Err()

return lastId, rows.Err()
}

// GetUpdatedOffers returns all offers created, updated, or deleted after the given ledger sequence.
45 changes: 38 additions & 7 deletions services/horizon/internal/db2/history/offers_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package history

import (
"context"
"database/sql"
"github.com/stretchr/testify/assert"
"strconv"
"testing"

@@ -105,14 +108,42 @@ func TestGetNonExistentOfferByID(t *testing.T) {
tt.Assert.True(q.NoRows(err))
}

func streamAllOffersInTx(q *Q, ctx context.Context, f func(offer Offer) error) error {
err := q.BeginTx(&sql.TxOptions{ReadOnly: true, Isolation: sql.LevelRepeatableRead})
if err != nil {
return err
}
defer q.Rollback()
return q.StreamAllOffers(ctx, f)
}

func TestStreamAllOffersRequiresTx(t *testing.T) {
tt := test.Start(t)
defer tt.Finish()
test.ResetHorizonDB(t, tt.HorizonDB)
q := &Q{tt.HorizonSession()}

err := q.StreamAllOffers(tt.Ctx, func(offer Offer) error {
return nil
})
assert.EqualError(t, err, "cannot be called outside of a transaction")

assert.NoError(t, q.Begin())
defer q.Rollback()
err = q.StreamAllOffers(tt.Ctx, func(offer Offer) error {
return nil
})
assert.EqualError(t, err, "should only be called in a repeatable read transaction")
}

func TestQueryEmptyOffers(t *testing.T) {
tt := test.Start(t)
defer tt.Finish()
test.ResetHorizonDB(t, tt.HorizonDB)
q := &Q{tt.HorizonSession()}

var offers []Offer
err := q.StreamAllOffers(tt.Ctx, func(offer Offer) error {
err := streamAllOffersInTx(q, tt.Ctx, func(offer Offer) error {
offers = append(offers, offer)
return nil
})
@@ -150,7 +181,7 @@ func TestInsertOffers(t *testing.T) {
tt.Assert.NoError(err)

var offers []Offer
err = q.StreamAllOffers(tt.Ctx, func(offer Offer) error {
err = streamAllOffersInTx(q, tt.Ctx, func(offer Offer) error {
offers = append(offers, offer)
return nil
})
@@ -183,7 +214,7 @@ func TestInsertOffers(t *testing.T) {
tt.Assert.Equal(3, afterCompactionCount)

var afterCompactionOffers []Offer
err = q.StreamAllOffers(tt.Ctx, func(offer Offer) error {
err = streamAllOffersInTx(q, tt.Ctx, func(offer Offer) error {
afterCompactionOffers = append(afterCompactionOffers, offer)
return nil
})
@@ -201,7 +232,7 @@ func TestUpdateOffer(t *testing.T) {
tt.Assert.NoError(err)

var offers []Offer
err = q.StreamAllOffers(tt.Ctx, func(offer Offer) error {
err = streamAllOffersInTx(q, tt.Ctx, func(offer Offer) error {
offers = append(offers, offer)
return nil
})
@@ -229,7 +260,7 @@ func TestUpdateOffer(t *testing.T) {
tt.Assert.NoError(err)

offers = nil
err = q.StreamAllOffers(tt.Ctx, func(offer Offer) error {
err = streamAllOffersInTx(q, tt.Ctx, func(offer Offer) error {
offers = append(offers, offer)
return nil
})
@@ -256,7 +287,7 @@ func TestRemoveOffer(t *testing.T) {
err := insertOffer(tt, q, eurOffer)
tt.Assert.NoError(err)
var offers []Offer
err = q.StreamAllOffers(tt.Ctx, func(offer Offer) error {
err = streamAllOffersInTx(q, tt.Ctx, func(offer Offer) error {
offers = append(offers, offer)
return nil
})
@@ -274,7 +305,7 @@ func TestRemoveOffer(t *testing.T) {
expectedUpdates[0].Deleted = true

offers = nil
err = q.StreamAllOffers(tt.Ctx, func(offer Offer) error {
err = streamAllOffersInTx(q, tt.Ctx, func(offer Offer) error {
offers = append(offers, offer)
return nil
})
103 changes: 77 additions & 26 deletions services/horizon/internal/db2/history/trade_aggregation.go
Original file line number Diff line number Diff line change
@@ -45,6 +45,8 @@ type TradeAggregation struct {
CloseD int64 `db:"close_d"`
}

const HistoryTradesTableName = "history_trades_60000"

// TradeAggregationsQ is a helper struct to aid in configuring queries to
// bucket and aggregate trades
type TradeAggregationsQ struct {
@@ -123,51 +125,100 @@ func (q *TradeAggregationsQ) WithEndTime(endTime strtime.Millis) (*TradeAggregat
}
}

// GetSql generates a sql statement to aggregate Trades based on given parameters
func (q *TradeAggregationsQ) GetSql() sq.SelectBuilder {
var orderPreserved bool
orderPreserved, q.baseAssetID, q.counterAssetID = getCanonicalAssetOrder(q.baseAssetID, q.counterAssetID)

var bucketSQL sq.SelectBuilder
func (q *TradeAggregationsQ) getRawTradesSql(orderPreserved bool) sq.SelectBuilder {
var rawTradesSQL sq.SelectBuilder
if orderPreserved {
bucketSQL = bucketTrades(q.resolution, q.offset)
rawTradesSQL = bucketTrades(q.resolution, q.offset)
} else {
bucketSQL = reverseBucketTrades(q.resolution, q.offset)
rawTradesSQL = reverseBucketTrades(q.resolution, q.offset)
}

bucketSQL = bucketSQL.From("history_trades_60000").
rawTradesSQL = rawTradesSQL.
Join("timestamp_range r ON 1=1").
From(fmt.Sprintf("%s AS tr", HistoryTradesTableName)).
Where(sq.Eq{"base_asset_id": q.baseAssetID, "counter_asset_id": q.counterAssetID})

//adjust time range and apply time filters
bucketSQL = bucketSQL.Where(sq.GtOrEq{"timestamp": q.startTime})
if !q.endTime.IsNil() {
bucketSQL = bucketSQL.Where(sq.Lt{"timestamp": q.endTime})
}
bucketTs := formatBucketTimestamp(q.resolution, q.offset, "tr")
rawTradesSQL = rawTradesSQL.
Where(fmt.Sprintf("r.max_ts >= %s", bucketTs)).
Where(fmt.Sprintf("r.min_ts <= %s", bucketTs))

if q.resolution != 60000 {
//ensure open/close order for cases when multiple trades occur in the same ledger
bucketSQL = bucketSQL.OrderBy("timestamp ASC", "open_ledger_toid ASC")
rawTradesSQL = rawTradesSQL.OrderBy("timestamp ASC", "open_ledger_toid ASC")
// Do on-the-fly aggregation for higher resolutions.
bucketSQL = aggregate(bucketSQL)
}
return rawTradesSQL
}

// GetSql generates a sql statement to aggregate Trades based on given parameters
func (q *TradeAggregationsQ) GetSql() sq.SelectBuilder {
var orderPreserved bool
orderPreserved, q.baseAssetID, q.counterAssetID = getCanonicalAssetOrder(q.baseAssetID, q.counterAssetID)

return bucketSQL.
bucketSQL := aggregate("raw_trades").
Limit(q.pagingParams.Limit).
OrderBy("timestamp " + q.pagingParams.Order)
OrderBy("timestamp "+q.pagingParams.Order).
Prefix("WITH last_range_ts AS (?),",
lastRangeTs(
q.baseAssetID, q.counterAssetID, q.resolution, q.offset, q.startTime, q.endTime,
q.pagingParams.Order, q.pagingParams.Limit)).
Prefix("timestamp_range AS (?),",
timestampRange()).
Prefix("raw_trades AS (?)",
q.getRawTradesSql(orderPreserved))

return bucketSQL
}

// formatBucketTimestampSelect formats a sql select clause for a bucketed timestamp, based on given resolution
// formatBucketTimestamp formats a sql select clause for a bucketed timestamp, based on given resolution
// and the offset. Given a time t, it gives it a timestamp defined by
// f(t) = ((t - offset)/resolution)*resolution + offset.
func formatBucketTimestampSelect(resolution int64, offset int64) string {
return fmt.Sprintf("((timestamp - %d) / %d) * %d + %d as timestamp", offset, resolution, resolution, offset)
func formatBucketTimestamp(resolution int64, offset int64, tsPrefix string) string {
prefix := ""
if len(tsPrefix) > 0 {
prefix = fmt.Sprintf("%s.", tsPrefix)
}
return fmt.Sprintf("((%stimestamp - %d) / %d) * %d + %d", prefix, offset, resolution, resolution, offset)
}

func formatBucketTimestampSelect(resolution int64, offset int64, tsPrefix string) string {
return fmt.Sprintf("%s AS timestamp", formatBucketTimestamp(resolution, offset, tsPrefix))
}

func lastRangeTs(baseAssetID, counterAssetID, resolution, offset int64, startTime, endTime strtime.Millis, order string, limit uint64) sq.SelectBuilder {
s := sq.Select(
formatBucketTimestampSelect(resolution, offset, ""),
).From(
HistoryTradesTableName,
).Where(
sq.Eq{"base_asset_id": baseAssetID, "counter_asset_id": counterAssetID},
).Where(sq.GtOrEq{"timestamp": startTime})
if !endTime.IsNil() {
s = s.Where(sq.Lt{"timestamp": endTime})
}
return s.GroupBy(
formatBucketTimestamp(resolution, offset, ""),
).OrderBy(
fmt.Sprintf("%s %s", formatBucketTimestamp(resolution, offset, ""), order),
).Suffix(
fmt.Sprintf("FETCH FIRST %d ROWS ONLY", limit),
)
}

func timestampRange() sq.SelectBuilder {
return sq.Select(
"min(timestamp) as min_ts",
"max(timestamp) as max_ts",
).From("last_range_ts")
}

// bucketTrades generates a select statement to filter rows from the `history_trades` table in
// a compact form, with a timestamp rounded to resolution and reversed base/counter.
func bucketTrades(resolution int64, offset int64) sq.SelectBuilder {
return sq.Select(
formatBucketTimestampSelect(resolution, offset),
formatBucketTimestampSelect(resolution, offset, "tr"),
"count",
"base_volume",
"counter_volume",
@@ -187,7 +238,7 @@ func bucketTrades(resolution int64, offset int64) sq.SelectBuilder {
// a compact form, with a timestamp rounded to resolution and reversed base/counter.
func reverseBucketTrades(resolution int64, offset int64) sq.SelectBuilder {
return sq.Select(
formatBucketTimestampSelect(resolution, offset),
formatBucketTimestampSelect(resolution, offset, "tr"),
"count",
"base_volume as counter_volume",
"counter_volume as base_volume",
@@ -203,7 +254,7 @@ func reverseBucketTrades(resolution int64, offset int64) sq.SelectBuilder {
)
}

func aggregate(query sq.SelectBuilder) sq.SelectBuilder {
func aggregate(rawTradesTable string) sq.SelectBuilder {
return sq.Select(
"timestamp",
"sum(\"count\") as count",
@@ -218,7 +269,7 @@ func aggregate(query sq.SelectBuilder) sq.SelectBuilder {
"(first(ARRAY[open_n, open_d]))[2] as open_d",
"(last(ARRAY[close_n, close_d]))[1] as close_n",
"(last(ARRAY[close_n, close_d]))[2] as close_d",
).FromSelect(query, "htrd").GroupBy("timestamp")
).From(rawTradesTable).GroupBy("timestamp")
}

// RebuildTradeAggregationTimes rebuilds a specific set of trade aggregation
@@ -228,7 +279,7 @@ func (q Q) RebuildTradeAggregationTimes(ctx context.Context, from, to strtime.Mi
from = from.RoundDown(60_000)
to = to.RoundDown(60_000)
// Clear out the old bucket values.
_, err := q.Exec(ctx, sq.Delete("history_trades_60000").Where(
_, err := q.Exec(ctx, sq.Delete(HistoryTradesTableName).Where(
sq.GtOrEq{"timestamp": from},
).Where(
sq.LtOrEq{"timestamp": to},
@@ -278,7 +329,7 @@ func (q Q) RebuildTradeAggregationTimes(ctx context.Context, from, to strtime.Mi
).FromSelect(trades, "trades").GroupBy("base_asset_id", "counter_asset_id", "timestamp")

// Insert the new bucket values.
_, err = q.Exec(ctx, sq.Insert("history_trades_60000").Select(rebuilt))
_, err = q.Exec(ctx, sq.Insert(HistoryTradesTableName).Select(rebuilt))
if err != nil {
return errors.Wrap(err, "could not rebuild trade aggregation bucket")
}
8 changes: 8 additions & 0 deletions services/horizon/internal/flags.go
Original file line number Diff line number Diff line change
@@ -400,6 +400,14 @@ func Flags() (*Config, support.ConfigOptions) {
Required: false,
Usage: "excludes liquidity pools from consideration in the `/paths` endpoint",
},
&support.ConfigOption{
Name: "disable-path-finding",
ConfigKey: &config.DisablePathFinding,
OptType: types.Bool,
FlagDefault: false,
Required: false,
Usage: "disables the path finding endpoints",
},
&support.ConfigOption{
Name: "max-path-finding-requests",
ConfigKey: &config.MaxPathFindingRequests,
34 changes: 18 additions & 16 deletions services/horizon/internal/httpx/router.go
Original file line number Diff line number Diff line change
@@ -197,22 +197,24 @@ func (r *Router) addRoutes(config *RouterConfig, rateLimiter *throttled.HTTPRate

r.With(stateMiddleware.Wrap).Method(http.MethodGet, "/assets", restPageHandler(ledgerState, actions.AssetStatsHandler{LedgerState: ledgerState}))

findPaths := ObjectActionHandler{actions.FindPathsHandler{
StaleThreshold: config.StaleThreshold,
SetLastLedgerHeader: true,
MaxPathLength: config.MaxPathLength,
MaxAssetsParamLength: config.MaxAssetsPerPathRequest,
PathFinder: config.PathFinder,
}}
findFixedPaths := ObjectActionHandler{actions.FindFixedPathsHandler{
MaxPathLength: config.MaxPathLength,
SetLastLedgerHeader: true,
MaxAssetsParamLength: config.MaxAssetsPerPathRequest,
PathFinder: config.PathFinder,
}}
r.With(stateMiddleware.Wrap).Method(http.MethodGet, "/paths", findPaths)
r.With(stateMiddleware.Wrap).Method(http.MethodGet, "/paths/strict-receive", findPaths)
r.With(stateMiddleware.Wrap).Method(http.MethodGet, "/paths/strict-send", findFixedPaths)
if config.PathFinder != nil {
findPaths := ObjectActionHandler{actions.FindPathsHandler{
StaleThreshold: config.StaleThreshold,
SetLastLedgerHeader: true,
MaxPathLength: config.MaxPathLength,
MaxAssetsParamLength: config.MaxAssetsPerPathRequest,
PathFinder: config.PathFinder,
}}
findFixedPaths := ObjectActionHandler{actions.FindFixedPathsHandler{
MaxPathLength: config.MaxPathLength,
SetLastLedgerHeader: true,
MaxAssetsParamLength: config.MaxAssetsPerPathRequest,
PathFinder: config.PathFinder,
}}
r.With(stateMiddleware.Wrap).Method(http.MethodGet, "/paths", findPaths)
r.With(stateMiddleware.Wrap).Method(http.MethodGet, "/paths/strict-receive", findPaths)
r.With(stateMiddleware.Wrap).Method(http.MethodGet, "/paths/strict-send", findFixedPaths)
}
r.With(stateMiddleware.Wrap).Method(
http.MethodGet,
"/order_book",
1 change: 1 addition & 0 deletions services/horizon/internal/httpx/server.go
Original file line number Diff line number Diff line change
@@ -58,6 +58,7 @@ func init() {
problem.RegisterError(context.Canceled, hProblem.ClientDisconnected)
problem.RegisterError(db.ErrCancelled, hProblem.ClientDisconnected)
problem.RegisterError(db.ErrTimeout, hProblem.ServiceUnavailable)
problem.RegisterError(db.ErrStatementTimeout, hProblem.ServiceUnavailable)
problem.RegisterError(db.ErrConflictWithRecovery, hProblem.ServiceUnavailable)
problem.RegisterError(db.ErrBadConnection, hProblem.ServiceUnavailable)
}
1 change: 0 additions & 1 deletion services/horizon/internal/ingest/orderbook.go
Original file line number Diff line number Diff line change
@@ -136,7 +136,6 @@ func (o *OrderBookStream) update(ctx context.Context, status ingestionStatus) (b
o.graph.AddOffers(offerToXDR(offer))
return nil
})

if err != nil {
return true, errors.Wrap(err, "Error loading offers into orderbook")
}
26 changes: 23 additions & 3 deletions services/horizon/internal/init.go
Original file line number Diff line number Diff line change
@@ -18,9 +18,9 @@ import (
"github.com/stellar/go/support/log"
)

func mustNewDBSession(subservice db.Subservice, databaseURL string, maxIdle, maxOpen int, registry *prometheus.Registry) db.SessionInterface {
func mustNewDBSession(subservice db.Subservice, databaseURL string, maxIdle, maxOpen int, registry *prometheus.Registry, clientConfigs ...db.ClientConfig) db.SessionInterface {
log.Infof("Establishing database session for %v", subservice)
session, err := db.Open("postgres", databaseURL)
session, err := db.Open("postgres", databaseURL, clientConfigs...)
if err != nil {
log.Fatalf("cannot open %v DB: %v", subservice, err)
}
@@ -47,21 +47,36 @@ func mustInitHorizonDB(app *App) {
}

if app.config.RoDatabaseURL == "" {
var clientConfigs []db.ClientConfig
if !app.config.Ingest {
// if we are not ingesting then we don't expect to have long db queries / transactions
clientConfigs = append(
clientConfigs,
db.StatementTimeout(app.config.ConnectionTimeout),
db.IdleTransactionTimeout(app.config.ConnectionTimeout),
)
}
app.historyQ = &history.Q{mustNewDBSession(
db.HistorySubservice,
app.config.DatabaseURL,
maxIdle,
maxOpen,
app.prometheusRegistry,
clientConfigs...,
)}
} else {
// If RO set, use it for all DB queries
roClientConfigs := []db.ClientConfig{
db.StatementTimeout(app.config.ConnectionTimeout),
db.IdleTransactionTimeout(app.config.ConnectionTimeout),
}
app.historyQ = &history.Q{mustNewDBSession(
db.HistorySubservice,
app.config.RoDatabaseURL,
maxIdle,
maxOpen,
app.prometheusRegistry,
roClientConfigs...,
)}

app.primaryHistoryQ = &history.Q{mustNewDBSession(
@@ -112,6 +127,9 @@ func initIngester(app *App) {
}

func initPathFinder(app *App) {
if app.config.DisablePathFinding {
return
}
orderBookGraph := orderbook.NewOrderBookGraph()
app.orderBookStream = ingest.NewOrderBookStream(
&history.Q{app.HorizonSession()},
@@ -178,7 +196,9 @@ func initDbMetrics(app *App) {

app.coreState.RegisterMetrics(app.prometheusRegistry)

app.prometheusRegistry.MustRegister(app.orderBookStream.LatestLedgerGauge)
if !app.config.DisablePathFinding {
app.prometheusRegistry.MustRegister(app.orderBookStream.LatestLedgerGauge)
}
}

// initGoMetrics registers the Go collector provided by prometheus package which
21 changes: 21 additions & 0 deletions services/horizon/internal/integration/parameters_test.go
Original file line number Diff line number Diff line change
@@ -190,6 +190,27 @@ func TestMaxPathFindingRequests(t *testing.T) {
})
}

func TestDisablePathFinding(t *testing.T) {
t.Run("default", func(t *testing.T) {
test := NewParameterTest(t, map[string]string{})
err := test.StartHorizon()
assert.NoError(t, err)
test.WaitForHorizon()
assert.Equal(t, test.Horizon().Config().MaxPathFindingRequests, uint(0))
_, ok := test.Horizon().Paths().(simplepath.InMemoryFinder)
assert.True(t, ok)
test.Shutdown()
})
t.Run("set to true", func(t *testing.T) {
test := NewParameterTest(t, map[string]string{"disable-path-finding": "true"})
err := test.StartHorizon()
assert.NoError(t, err)
test.WaitForHorizon()
assert.Nil(t, test.Horizon().Paths())
test.Shutdown()
})
}

// Pattern taken from testify issue:
// https://github.com/stretchr/testify/issues/858#issuecomment-600491003
//
59 changes: 58 additions & 1 deletion support/db/main.go
Original file line number Diff line number Diff line change
@@ -14,6 +14,9 @@ package db
import (
"context"
"database/sql"
"net/url"
"strconv"
"strings"
"time"

"github.com/Masterminds/squirrel"
@@ -44,6 +47,9 @@ var (
// ErrBadConnection is an error returned when driver returns `bad connection`
// error.
ErrBadConnection = errors.New("bad connection")
// ErrStatementTimeout is an error returned by Session methods when request has
// been cancelled due to a statement timeout.
ErrStatementTimeout = errors.New("canceling statement due to statement timeout")
)

// Conn represents a connection to a single database.
@@ -163,8 +169,59 @@ func pingDB(db *sqlx.DB) error {
return errors.Wrapf(err, "failed to connect to DB after %v attempts", maxDBPingAttempts)
}

type ClientConfig struct {
Key string
Value string
}

func StatementTimeout(timeout time.Duration) ClientConfig {
return ClientConfig{
Key: "statement_timeout",
Value: strconv.FormatInt(timeout.Milliseconds(), 10),
}
}

func IdleTransactionTimeout(timeout time.Duration) ClientConfig {
return ClientConfig{
Key: "idle_in_transaction_session_timeout",
Value: strconv.FormatInt(timeout.Milliseconds(), 10),
}
}

func augmentDSN(dsn string, clientConfigs []ClientConfig) string {
parsed, err := url.Parse(dsn)
// dsn can either be a postgres url like "postgres://postgres:123456@127.0.0.1:5432"
// or, it can be a white space separated string of key value pairs like
// "host=localhost port=5432 user=bob password=secret"
if err != nil || parsed.Scheme == "" {
// if dsn does not parse as a postgres url, we assume it must be take
// the form of a white space separated string
parts := []string{dsn}
for _, config := range clientConfigs {
// do not override if the key is already present in dsn
if strings.Contains(dsn, config.Key+"=") {
continue
}
parts = append(parts, config.Key+"="+config.Value)
}
return strings.Join(parts, " ")
}

q := parsed.Query()
for _, config := range clientConfigs {
// do not override if the key is already present in dsn
if len(q.Get(config.Key)) > 0 {
continue
}
q.Set(config.Key, config.Value)
}
parsed.RawQuery = q.Encode()
return parsed.String()
}

// Open the database at `dsn` and returns a new *Session using it.
func Open(dialect, dsn string) (*Session, error) {
func Open(dialect, dsn string, clientConfigs ...ClientConfig) (*Session, error) {
dsn = augmentDSN(dsn, clientConfigs)
db, err := sqlx.Open(dialect, dsn)
if err != nil {
return nil, errors.Wrap(err, "open failed")
25 changes: 25 additions & 0 deletions support/db/main_test.go
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@ package db

import (
"testing"
"time"

"github.com/stellar/go/support/db/dbtest"
"github.com/stretchr/testify/assert"
@@ -27,3 +28,27 @@ func TestGetTable(t *testing.T) {
}

}

func TestAugmentDSN(t *testing.T) {
configs := []ClientConfig{
IdleTransactionTimeout(2 * time.Second),
StatementTimeout(4 * time.Millisecond),
}
for _, testCase := range []struct {
input string
expected string
}{
{"postgresql://localhost", "postgresql://localhost?idle_in_transaction_session_timeout=2000&statement_timeout=4"},
{"postgresql://localhost/mydb?user=other&password=secret", "postgresql://localhost/mydb?idle_in_transaction_session_timeout=2000&password=secret&statement_timeout=4&user=other"},
{"postgresql://localhost/mydb?user=other&idle_in_transaction_session_timeout=500", "postgresql://localhost/mydb?idle_in_transaction_session_timeout=500&statement_timeout=4&user=other"},
{"host=localhost user=bob password=secret", "host=localhost user=bob password=secret idle_in_transaction_session_timeout=2000 statement_timeout=4"},
{"host=localhost user=bob password=secret statement_timeout=32", "host=localhost user=bob password=secret statement_timeout=32 idle_in_transaction_session_timeout=2000"},
} {
t.Run(testCase.input, func(t *testing.T) {
output := augmentDSN(testCase.input, configs)
if output != testCase.expected {
t.Fatalf("got %v but expected %v", output, testCase.expected)
}
})
}
}
14 changes: 14 additions & 0 deletions support/db/session.go
Original file line number Diff line number Diff line change
@@ -91,6 +91,10 @@ func (s *Session) Commit() error {
log.Debug("sql: commit")
s.tx = nil
s.txOptions = nil

if knownErr := s.replaceWithKnownError(err, context.Background()); knownErr != nil {
return knownErr
}
return err
}

@@ -231,6 +235,10 @@ func (s *Session) NoRows(err error) bool {
// replaceWithKnownError tries to replace Postgres error with package error.
// Returns a new error if the err is known.
func (s *Session) replaceWithKnownError(err error, ctx context.Context) error {
if err == nil {
return nil
}

switch {
case ctx.Err() == context.Canceled:
return ErrCancelled
@@ -243,6 +251,8 @@ func (s *Session) replaceWithKnownError(err error, ctx context.Context) error {
return ErrConflictWithRecovery
case strings.Contains(err.Error(), "driver: bad connection"):
return ErrBadConnection
case strings.Contains(err.Error(), "pq: canceling statement due to statement timeout"):
return ErrStatementTimeout
default:
return nil
}
@@ -305,6 +315,10 @@ func (s *Session) Rollback() error {
log.Debug("sql: rollback")
s.tx = nil
s.txOptions = nil

if knownErr := s.replaceWithKnownError(err, context.Background()); knownErr != nil {
return knownErr
}
return err
}

31 changes: 31 additions & 0 deletions support/db/session_test.go
Original file line number Diff line number Diff line change
@@ -129,3 +129,34 @@ func TestSession(t *testing.T) {
assert.Equal("$1 = $2 = $3 = ?", out)
}
}

func TestStatementTimeout(t *testing.T) {
assert := assert.New(t)
db := dbtest.Postgres(t).Load(testSchema)
defer db.Close()

sess, err := Open(db.Dialect, db.DSN, StatementTimeout(50*time.Millisecond))
assert.NoError(err)
defer sess.Close()

var count int
err = sess.GetRaw(context.Background(), &count, "SELECT pg_sleep(2), COUNT(*) FROM people")
assert.ErrorIs(err, ErrStatementTimeout)
}

func TestIdleTransactionTimeout(t *testing.T) {
assert := assert.New(t)
db := dbtest.Postgres(t).Load(testSchema)
defer db.Close()

sess, err := Open(db.Dialect, db.DSN, IdleTransactionTimeout(50*time.Millisecond))
assert.NoError(err)
defer sess.Close()

assert.NoError(sess.Begin())
<-time.After(100 * time.Millisecond)

var count int
err = sess.GetRaw(context.Background(), &count, "SELECT COUNT(*) FROM people")
assert.ErrorIs(err, ErrBadConnection)
}