Skip to content

Commit

Permalink
Dbname issue (#273)
Browse files Browse the repository at this point in the history
* Added sample apps code

* appended host as a subsegment name in the case of known DSL

* fix file format

* gofmt
  • Loading branch information
bhautikpip authored Jan 26, 2021
1 parent ca90926 commit bd68eed
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 33 deletions.
28 changes: 19 additions & 9 deletions xray/sql_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"database/sql/driver"
"errors"
"fmt"
"net"
"net/url"
"reflect"
"strconv"
Expand Down Expand Up @@ -106,7 +107,7 @@ type driverConn struct {
}

func (conn *driverConn) Ping(ctx context.Context) error {
return Capture(ctx, conn.attr.dbname, func(ctx context.Context) error {
return Capture(ctx, conn.attr.dbname+conn.attr.host, func(ctx context.Context) error {
conn.attr.populate(ctx, "PING")
if p, ok := conn.Conn.(driver.Pinger); ok {
return p.Ping(ctx)
Expand Down Expand Up @@ -191,7 +192,7 @@ func (conn *driverConn) ExecContext(ctx context.Context, query string, args []dr
var err error
var result driver.Result
if execerCtx, ok := conn.Conn.(driver.ExecerContext); ok {
Capture(ctx, conn.attr.dbname, func(ctx context.Context) error {
Capture(ctx, conn.attr.dbname+conn.attr.host, func(ctx context.Context) error {
result, err = execerCtx.ExecContext(ctx, query, args)
if err == driver.ErrSkip {
conn.attr.populate(ctx, query+msgErrSkip)
Expand All @@ -210,7 +211,7 @@ func (conn *driverConn) ExecContext(ctx context.Context, query string, args []dr
if err0 != nil {
return nil, err0
}
Capture(ctx, conn.attr.dbname, func(ctx context.Context) error {
Capture(ctx, conn.attr.dbname+conn.attr.host, func(ctx context.Context) error {
var err error
result, err = execer.Exec(query, dargs)
if err == driver.ErrSkip {
Expand All @@ -237,7 +238,7 @@ func (conn *driverConn) QueryContext(ctx context.Context, query string, args []d
var err error
var rows driver.Rows
if queryerCtx, ok := conn.Conn.(driver.QueryerContext); ok {
Capture(ctx, conn.attr.dbname, func(ctx context.Context) error {
Capture(ctx, conn.attr.dbname+conn.attr.host, func(ctx context.Context) error {
rows, err = queryerCtx.QueryContext(ctx, query, args)
if err == driver.ErrSkip {
conn.attr.populate(ctx, query+msgErrSkip)
Expand All @@ -256,7 +257,7 @@ func (conn *driverConn) QueryContext(ctx context.Context, query string, args []d
if err0 != nil {
return nil, err0
}
err = Capture(ctx, conn.attr.dbname, func(ctx context.Context) error {
err = Capture(ctx, conn.attr.dbname+conn.attr.host, func(ctx context.Context) error {
rows, err = queryer.Query(query, dargs)
if err == driver.ErrSkip {
conn.attr.populate(ctx, query+msgErrSkip)
Expand Down Expand Up @@ -301,6 +302,7 @@ type dbAttribute struct {
driverVersion string
user string
dbname string
host string
}

func newDBAttribute(ctx context.Context, driverName string, d driver.Driver, conn driver.Conn, dsn string, filtered bool) (*dbAttribute, error) {
Expand Down Expand Up @@ -341,6 +343,14 @@ func newDBAttribute(ctx context.Context, driverName string, d driver.Driver, con
q.Del("password")
u.RawQuery = q.Encode()

// In the case of known DSL sub segment name will be dbname@host
host, _, _ := net.SplitHostPort(u.Host)
if len(host) > 0 {
attr.host = "@" + host
} else {
attr.host = host
}

attr.url = u.String()
if !strings.Contains(dsn, "//") {
attr.url = attr.url[2:]
Expand Down Expand Up @@ -562,7 +572,7 @@ func (stmt *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValu
var result driver.Result
var err error
if execerContext, ok := stmt.Stmt.(driver.StmtExecContext); ok {
err = Capture(ctx, stmt.attr.dbname, func(ctx context.Context) error {
err = Capture(ctx, stmt.attr.dbname+stmt.attr.host, func(ctx context.Context) error {
stmt.populate(ctx)
var err error
result, err = execerContext.ExecContext(ctx, args)
Expand All @@ -578,7 +588,7 @@ func (stmt *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValu
if err0 != nil {
return nil, err0
}
err = Capture(ctx, stmt.attr.dbname, func(ctx context.Context) error {
err = Capture(ctx, stmt.attr.dbname+stmt.attr.host, func(ctx context.Context) error {
stmt.populate(ctx)
var err error
result, err = stmt.Stmt.Exec(dargs)
Expand All @@ -599,7 +609,7 @@ func (stmt *driverStmt) QueryContext(ctx context.Context, args []driver.NamedVal
var result driver.Rows
var err error
if queryCtx, ok := stmt.Stmt.(driver.StmtQueryContext); ok {
err = Capture(ctx, stmt.attr.dbname, func(ctx context.Context) error {
err = Capture(ctx, stmt.attr.dbname+stmt.attr.host, func(ctx context.Context) error {
stmt.populate(ctx)
var err error
result, err = queryCtx.QueryContext(ctx, args)
Expand All @@ -615,7 +625,7 @@ func (stmt *driverStmt) QueryContext(ctx context.Context, args []driver.NamedVal
if err0 != nil {
return nil, err0
}
err = Capture(ctx, stmt.attr.dbname, func(ctx context.Context) error {
err = Capture(ctx, stmt.attr.dbname+stmt.attr.host, func(ctx context.Context) error {
stmt.populate(ctx)
var err error
result, err = stmt.Stmt.Query(dargs)
Expand Down
2 changes: 1 addition & 1 deletion xray/sql_go110.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (c *driverConnector) Connect(ctx context.Context) (driver.Conn, error) {
if err != nil {
return nil, err
}
err = Capture(ctx, attr.dbname, func(ctx context.Context) error {
err = Capture(ctx, attr.dbname+attr.host, func(ctx context.Context) error {
attr.populate(ctx, "CONNECT")
var err error
rawConn, err = c.Connector.Connect(ctx)
Expand Down
58 changes: 35 additions & 23 deletions xray/sqlcontext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,49 +74,60 @@ func capturePing(dsn string) (*Segment, error) {

func TestDSN(t *testing.T) {
tc := []struct {
dsn string
url string
str string
dsn string
url string
str string
name string
}{
{
dsn: "postgres://user@host:5432/database",
url: "postgres://user@host:5432/database",
dsn: "postgres://user@host:5432/database",
url: "postgres://user@host:5432/database",
name: "test database@host",
},
{
dsn: "postgres://user:password@host:5432/database",
url: "postgres://user@host:5432/database",
dsn: "postgres://user:password@host:5432/database",
url: "postgres://user@host:5432/database",
name: "test database@host",
},
{
dsn: "postgres://host:5432/database?password=password",
url: "postgres://host:5432/database",
dsn: "postgres://host:5432/database?password=password",
url: "postgres://host:5432/database",
name: "test database@host",
},
{
dsn: "user:password@host:5432/database",
url: "user@host:5432/database",
dsn: "user:password@host:5432/database",
url: "user@host:5432/database",
name: "test database@host",
},
{
dsn: "host:5432/database?password=password",
url: "host:5432/database",
dsn: "host:5432/database?password=password",
url: "host:5432/database",
name: "test database@host",
},
{
dsn: "user%2Fpassword@host:5432/database",
url: "user@host:5432/database",
dsn: "user%2Fpassword@host:5432/database",
url: "user@host:5432/database",
name: "test database@host",
},
{
dsn: "user/password@host:5432/database",
url: "user@host:5432/database",
dsn: "user/password@host:5432/database",
url: "user@host:5432/database",
name: "test database@host",
},
{
dsn: "user=user database=database",
str: "user=user database=database",
dsn: "user=user database=database",
str: "user=user database=database",
name: "test database",
},
{
dsn: "user=user password=password database=database",
str: "user=user database=database",
dsn: "user=user password=password database=database",
str: "user=user database=database",
name: "test database",
},
{
dsn: "odbc:server=localhost;user id=sa;password={foo}};bar};otherthing=thing",
str: "odbc:server=localhost;user id=sa;otherthing=thing",
dsn: "odbc:server=localhost;user id=sa;password={foo}};bar};otherthing=thing",
str: "odbc:server=localhost;user id=sa;otherthing=thing",
name: "test database",
},
}

Expand All @@ -142,6 +153,7 @@ func TestDSN(t *testing.T) {
assert.Equal(t, tt.str, subseg.SQL.ConnectionString)
assert.Equal(t, "test version", subseg.SQL.DatabaseVersion)
assert.Equal(t, "test user", subseg.SQL.User)
assert.Equal(t, tt.name, subseg.Name)
assert.False(t, subseg.Throttle)
assert.False(t, subseg.Error)
assert.False(t, subseg.Fault)
Expand Down

0 comments on commit bd68eed

Please sign in to comment.