Skip to content

Commit

Permalink
feat(pgdriver): improve otel instrumentation
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Jan 22, 2025
1 parent 3d8666a commit c40e4f3
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 52 deletions.
91 changes: 60 additions & 31 deletions driver/pgdriver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
"time"

"github.com/uptrace/bun/internal"
semconv "go.opentelemetry.io/otel/semconv/v1.21.0"
"go.opentelemetry.io/otel/trace"
)

func init() {
Expand Down Expand Up @@ -68,38 +70,38 @@ func (d Driver) Open(name string) (driver.Conn, error) {
//------------------------------------------------------------------------------

type Connector struct {
cfg *Config
conf *Config
}

func NewConnector(opts ...Option) *Connector {
c := &Connector{cfg: newDefaultConfig()}
c := &Connector{conf: newDefaultConfig()}
for _, opt := range opts {
opt(c.cfg)
opt(c.conf)
}
return c
}

var _ driver.Connector = (*Connector)(nil)

func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
if err := c.cfg.verify(); err != nil {
if err := c.conf.verify(); err != nil {
return nil, err
}
return newConn(ctx, c.cfg)
return newConn(ctx, c.conf)
}

func (c *Connector) Driver() driver.Driver {
return Driver{connector: c}
}

func (c *Connector) Config() *Config {
return c.cfg
return c.conf
}

//------------------------------------------------------------------------------

type Conn struct {
cfg *Config
conf *Config

netConn net.Conn
rd *reader
Expand All @@ -112,20 +114,20 @@ type Conn struct {
closed int32
}

func newConn(ctx context.Context, cfg *Config) (*Conn, error) {
netConn, err := cfg.Dialer(ctx, cfg.Network, cfg.Addr)
func newConn(ctx context.Context, conf *Config) (*Conn, error) {
netConn, err := conf.Dialer(ctx, conf.Network, conf.Addr)
if err != nil {
return nil, err
}

cn := &Conn{
cfg: cfg,
conf: conf,
netConn: netConn,
rd: newReader(netConn),
}

if cfg.TLSConfig != nil {
if err := enableSSL(ctx, cn, cfg.TLSConfig); err != nil {
if conf.TLSConfig != nil {
if err := enableSSL(ctx, cn, conf.TLSConfig); err != nil {
return nil, err
}
}
Expand All @@ -134,7 +136,7 @@ func newConn(ctx context.Context, cfg *Config) (*Conn, error) {
return nil, err
}

for k, v := range cfg.ConnParams {
for k, v := range conf.ConnParams {
if v != nil {
_, err = cn.ExecContext(ctx, fmt.Sprintf("SET %s TO $1", k), []driver.NamedValue{
{Value: v},
Expand All @@ -150,6 +152,17 @@ func newConn(ctx context.Context, cfg *Config) (*Conn, error) {
return cn, nil
}

func (cn *Conn) Close() error {
if !atomic.CompareAndSwapInt32(&cn.closed, 0, 1) {
return nil
}
return cn.netConn.Close()
}

func (cn *Conn) isClosed() bool {
return atomic.LoadInt32(&cn.closed) == 1
}

func (cn *Conn) reader(ctx context.Context, timeout time.Duration) *reader {
cn.setReadDeadline(ctx, timeout)
return cn.rd
Expand All @@ -174,11 +187,16 @@ func (cn *Conn) write(ctx context.Context, wb *writeBuffer) error {
var _ driver.Conn = (*Conn)(nil)

func (cn *Conn) Prepare(query string) (driver.Stmt, error) {
return cn.PrepareContext(context.Background(), query)
}

var _ driver.ConnPrepareContext = (*Conn)(nil)

func (cn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if cn.isClosed() {
return nil, driver.ErrBadConn
}

ctx := context.TODO()
cn.trace(ctx)

name := fmt.Sprintf("pgdriver-%d", cn.stmtCount)
cn.stmtCount++
Expand All @@ -195,32 +213,29 @@ func (cn *Conn) Prepare(query string) (driver.Stmt, error) {
return newStmt(cn, name, rowDesc), nil
}

func (cn *Conn) Close() error {
if !atomic.CompareAndSwapInt32(&cn.closed, 0, 1) {
return nil
}
return cn.netConn.Close()
}

func (cn *Conn) isClosed() bool {
return atomic.LoadInt32(&cn.closed) == 1
}

func (cn *Conn) Begin() (driver.Tx, error) {
return cn.BeginTx(context.Background(), driver.TxOptions{})
}

var _ driver.ConnBeginTx = (*Conn)(nil)

func (cn *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if cn.isClosed() {
return nil, driver.ErrBadConn
}
cn.trace(ctx)

// No need to check if the conn is closed. ExecContext below handles that.
isolation := sql.IsolationLevel(opts.Isolation)

var command string
switch isolation {
case sql.LevelDefault:
command = "BEGIN"
case sql.LevelReadUncommitted, sql.LevelReadCommitted, sql.LevelRepeatableRead, sql.LevelSerializable:
case sql.LevelReadUncommitted,
sql.LevelReadCommitted,
sql.LevelRepeatableRead,
sql.LevelSerializable:
command = fmt.Sprintf("BEGIN; SET TRANSACTION ISOLATION LEVEL %s", isolation.String())
default:
return nil, fmt.Errorf("pgdriver: unsupported transaction isolation: %s", isolation.String())
Expand All @@ -244,6 +259,8 @@ func (cn *Conn) ExecContext(
if cn.isClosed() {
return nil, driver.ErrBadConn
}
cn.trace(ctx)

res, err := cn.exec(ctx, query, args)
if err != nil {
return nil, cn.checkBadConn(err)
Expand Down Expand Up @@ -272,6 +289,8 @@ func (cn *Conn) QueryContext(
if cn.isClosed() {
return nil, driver.ErrBadConn
}
cn.trace(ctx)

rows, err := cn.query(ctx, query, args)
if err != nil {
return nil, cn.checkBadConn(err)
Expand Down Expand Up @@ -301,14 +320,14 @@ func (cn *Conn) Ping(ctx context.Context) error {

func (cn *Conn) setReadDeadline(ctx context.Context, timeout time.Duration) {
if timeout == -1 {
timeout = cn.cfg.ReadTimeout
timeout = cn.conf.ReadTimeout
}
_ = cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout))
}

func (cn *Conn) setWriteDeadline(ctx context.Context, timeout time.Duration) {
if timeout == -1 {
timeout = cn.cfg.WriteTimeout
timeout = cn.conf.WriteTimeout
}
_ = cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout))
}
Expand Down Expand Up @@ -343,8 +362,8 @@ func (cn *Conn) ResetSession(ctx context.Context) error {
if cn.isClosed() {
return driver.ErrBadConn
}
if cn.cfg.ResetSessionFunc != nil {
return cn.cfg.ResetSessionFunc(ctx, cn)
if cn.conf.ResetSessionFunc != nil {
return cn.conf.ResetSessionFunc(ctx, cn)
}
return nil
}
Expand All @@ -360,6 +379,16 @@ func (cn *Conn) checkBadConn(err error) error {

func (cn *Conn) Conn() net.Conn { return cn.netConn }

func (cn *Conn) trace(ctx context.Context) {
if span := trace.SpanFromContext(ctx); span.IsRecording() {
span.SetAttributes(
semconv.DBUserKey.String(cn.conf.User),
semconv.DBNameKey.String(cn.conf.Database),
semconv.ServerAddressKey.String(cn.conf.Addr),
)
}
}

//------------------------------------------------------------------------------

type rows struct {
Expand Down
5 changes: 3 additions & 2 deletions driver/pgdriver/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@ go 1.22.0
replace github.com/uptrace/bun => ../..

require (
github.com/stretchr/testify v1.8.1
github.com/stretchr/testify v1.10.0
github.com/uptrace/bun v1.2.8
go.opentelemetry.io/otel v1.34.0
go.opentelemetry.io/otel/trace v1.34.0
mellium.im/sasl v0.3.2
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/kr/text v0.1.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/puzpuzpuz/xsync/v3 v3.4.0 // indirect
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
Expand Down
22 changes: 10 additions & 12 deletions driver/pgdriver/go.sum
Original file line number Diff line number Diff line change
@@ -1,38 +1,36 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/puzpuzpuz/xsync/v3 v3.4.0 h1:DuVBAdXuGFHv8adVXjWWZ63pJq+NRXOWVXlKDBZ+mJ4=
github.com/puzpuzpuz/xsync/v3 v3.4.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
go.opentelemetry.io/otel v1.34.0 h1:zRLXxLCgL1WyKsPVrgbSdMN4c0FMkDAskSTQP+0hdUY=
go.opentelemetry.io/otel v1.34.0/go.mod h1:OWFPOQ+h4G8xpyjgqo4SxJYdDQ/qmRH+wivy7zzx9oI=
go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC8mh/k=
go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
mellium.im/sasl v0.3.2 h1:PT6Xp7ccn9XaXAnJ03FcEjmAn7kK1x7aoXV6F+Vmrl0=
Expand Down
14 changes: 7 additions & 7 deletions driver/pgdriver/proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,12 @@ func writeStartup(ctx context.Context, cn *Conn) error {
wb.StartMessage(0)
wb.WriteInt32(196608)
wb.WriteString("user")
wb.WriteString(cn.cfg.User)
wb.WriteString(cn.conf.User)
wb.WriteString("database")
wb.WriteString(cn.cfg.Database)
if cn.cfg.AppName != "" {
wb.WriteString(cn.conf.Database)
if cn.conf.AppName != "" {
wb.WriteString("application_name")
wb.WriteString(cn.cfg.AppName)
wb.WriteString(cn.conf.AppName)
}
wb.WriteString("")
wb.FinishMessage()
Expand Down Expand Up @@ -239,7 +239,7 @@ func auth(ctx context.Context, cn *Conn, rd *reader) error {
}

func authCleartext(ctx context.Context, cn *Conn, rd *reader) error {
if err := writePassword(ctx, cn, cn.cfg.Password); err != nil {
if err := writePassword(ctx, cn, cn.conf.Password); err != nil {
return err
}
return readAuthOK(cn, rd)
Expand Down Expand Up @@ -280,7 +280,7 @@ func authMD5(ctx context.Context, cn *Conn, rd *reader) error {
return err
}

secret := "md5" + md5s(md5s(cn.cfg.Password+cn.cfg.User)+string(b))
secret := "md5" + md5s(md5s(cn.conf.Password+cn.conf.User)+string(b))
if err := writePassword(ctx, cn, secret); err != nil {
return err
}
Expand Down Expand Up @@ -329,7 +329,7 @@ loop:
}

creds := sasl.Credentials(func() (Username, Password, Identity []byte) {
return []byte(cn.cfg.User), []byte(cn.cfg.Password), nil
return []byte(cn.conf.User), []byte(cn.conf.Password), nil
})
client := sasl.NewClient(saslMech, creds)

Expand Down

0 comments on commit c40e4f3

Please sign in to comment.