From bd68eeda041f0f8057f4cf4c7feff4617c74c83a Mon Sep 17 00:00:00 2001 From: Bhautik Pipaliya <56270044+bhautikpip@users.noreply.github.com> Date: Tue, 26 Jan 2021 15:24:06 -0800 Subject: [PATCH] Dbname issue (#273) * Added sample apps code * appended host as a subsegment name in the case of known DSL * fix file format * gofmt --- xray/sql_context.go | 28 +++++++++++++------- xray/sql_go110.go | 2 +- xray/sqlcontext_test.go | 58 +++++++++++++++++++++++++---------------- 3 files changed, 55 insertions(+), 33 deletions(-) diff --git a/xray/sql_context.go b/xray/sql_context.go index e7a3f2e7..2f878164 100644 --- a/xray/sql_context.go +++ b/xray/sql_context.go @@ -15,6 +15,7 @@ import ( "database/sql/driver" "errors" "fmt" + "net" "net/url" "reflect" "strconv" @@ -106,7 +107,7 @@ type driverConn struct { } func (conn *driverConn) Ping(ctx context.Context) error { - return Capture(ctx, conn.attr.dbname, func(ctx context.Context) error { + return Capture(ctx, conn.attr.dbname+conn.attr.host, func(ctx context.Context) error { conn.attr.populate(ctx, "PING") if p, ok := conn.Conn.(driver.Pinger); ok { return p.Ping(ctx) @@ -191,7 +192,7 @@ func (conn *driverConn) ExecContext(ctx context.Context, query string, args []dr var err error var result driver.Result if execerCtx, ok := conn.Conn.(driver.ExecerContext); ok { - Capture(ctx, conn.attr.dbname, func(ctx context.Context) error { + Capture(ctx, conn.attr.dbname+conn.attr.host, func(ctx context.Context) error { result, err = execerCtx.ExecContext(ctx, query, args) if err == driver.ErrSkip { conn.attr.populate(ctx, query+msgErrSkip) @@ -210,7 +211,7 @@ func (conn *driverConn) ExecContext(ctx context.Context, query string, args []dr if err0 != nil { return nil, err0 } - Capture(ctx, conn.attr.dbname, func(ctx context.Context) error { + Capture(ctx, conn.attr.dbname+conn.attr.host, func(ctx context.Context) error { var err error result, err = execer.Exec(query, dargs) if err == driver.ErrSkip { @@ -237,7 +238,7 @@ func (conn *driverConn) QueryContext(ctx context.Context, query string, args []d var err error var rows driver.Rows if queryerCtx, ok := conn.Conn.(driver.QueryerContext); ok { - Capture(ctx, conn.attr.dbname, func(ctx context.Context) error { + Capture(ctx, conn.attr.dbname+conn.attr.host, func(ctx context.Context) error { rows, err = queryerCtx.QueryContext(ctx, query, args) if err == driver.ErrSkip { conn.attr.populate(ctx, query+msgErrSkip) @@ -256,7 +257,7 @@ func (conn *driverConn) QueryContext(ctx context.Context, query string, args []d if err0 != nil { return nil, err0 } - err = Capture(ctx, conn.attr.dbname, func(ctx context.Context) error { + err = Capture(ctx, conn.attr.dbname+conn.attr.host, func(ctx context.Context) error { rows, err = queryer.Query(query, dargs) if err == driver.ErrSkip { conn.attr.populate(ctx, query+msgErrSkip) @@ -301,6 +302,7 @@ type dbAttribute struct { driverVersion string user string dbname string + host string } func newDBAttribute(ctx context.Context, driverName string, d driver.Driver, conn driver.Conn, dsn string, filtered bool) (*dbAttribute, error) { @@ -341,6 +343,14 @@ func newDBAttribute(ctx context.Context, driverName string, d driver.Driver, con q.Del("password") u.RawQuery = q.Encode() + // In the case of known DSL sub segment name will be dbname@host + host, _, _ := net.SplitHostPort(u.Host) + if len(host) > 0 { + attr.host = "@" + host + } else { + attr.host = host + } + attr.url = u.String() if !strings.Contains(dsn, "//") { attr.url = attr.url[2:] @@ -562,7 +572,7 @@ func (stmt *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValu var result driver.Result var err error if execerContext, ok := stmt.Stmt.(driver.StmtExecContext); ok { - err = Capture(ctx, stmt.attr.dbname, func(ctx context.Context) error { + err = Capture(ctx, stmt.attr.dbname+stmt.attr.host, func(ctx context.Context) error { stmt.populate(ctx) var err error result, err = execerContext.ExecContext(ctx, args) @@ -578,7 +588,7 @@ func (stmt *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValu if err0 != nil { return nil, err0 } - err = Capture(ctx, stmt.attr.dbname, func(ctx context.Context) error { + err = Capture(ctx, stmt.attr.dbname+stmt.attr.host, func(ctx context.Context) error { stmt.populate(ctx) var err error result, err = stmt.Stmt.Exec(dargs) @@ -599,7 +609,7 @@ func (stmt *driverStmt) QueryContext(ctx context.Context, args []driver.NamedVal var result driver.Rows var err error if queryCtx, ok := stmt.Stmt.(driver.StmtQueryContext); ok { - err = Capture(ctx, stmt.attr.dbname, func(ctx context.Context) error { + err = Capture(ctx, stmt.attr.dbname+stmt.attr.host, func(ctx context.Context) error { stmt.populate(ctx) var err error result, err = queryCtx.QueryContext(ctx, args) @@ -615,7 +625,7 @@ func (stmt *driverStmt) QueryContext(ctx context.Context, args []driver.NamedVal if err0 != nil { return nil, err0 } - err = Capture(ctx, stmt.attr.dbname, func(ctx context.Context) error { + err = Capture(ctx, stmt.attr.dbname+stmt.attr.host, func(ctx context.Context) error { stmt.populate(ctx) var err error result, err = stmt.Stmt.Query(dargs) diff --git a/xray/sql_go110.go b/xray/sql_go110.go index a034b80d..dd9996bd 100644 --- a/xray/sql_go110.go +++ b/xray/sql_go110.go @@ -56,7 +56,7 @@ func (c *driverConnector) Connect(ctx context.Context) (driver.Conn, error) { if err != nil { return nil, err } - err = Capture(ctx, attr.dbname, func(ctx context.Context) error { + err = Capture(ctx, attr.dbname+attr.host, func(ctx context.Context) error { attr.populate(ctx, "CONNECT") var err error rawConn, err = c.Connector.Connect(ctx) diff --git a/xray/sqlcontext_test.go b/xray/sqlcontext_test.go index 389719cd..79e8eeb1 100644 --- a/xray/sqlcontext_test.go +++ b/xray/sqlcontext_test.go @@ -74,49 +74,60 @@ func capturePing(dsn string) (*Segment, error) { func TestDSN(t *testing.T) { tc := []struct { - dsn string - url string - str string + dsn string + url string + str string + name string }{ { - dsn: "postgres://user@host:5432/database", - url: "postgres://user@host:5432/database", + dsn: "postgres://user@host:5432/database", + url: "postgres://user@host:5432/database", + name: "test database@host", }, { - dsn: "postgres://user:password@host:5432/database", - url: "postgres://user@host:5432/database", + dsn: "postgres://user:password@host:5432/database", + url: "postgres://user@host:5432/database", + name: "test database@host", }, { - dsn: "postgres://host:5432/database?password=password", - url: "postgres://host:5432/database", + dsn: "postgres://host:5432/database?password=password", + url: "postgres://host:5432/database", + name: "test database@host", }, { - dsn: "user:password@host:5432/database", - url: "user@host:5432/database", + dsn: "user:password@host:5432/database", + url: "user@host:5432/database", + name: "test database@host", }, { - dsn: "host:5432/database?password=password", - url: "host:5432/database", + dsn: "host:5432/database?password=password", + url: "host:5432/database", + name: "test database@host", }, { - dsn: "user%2Fpassword@host:5432/database", - url: "user@host:5432/database", + dsn: "user%2Fpassword@host:5432/database", + url: "user@host:5432/database", + name: "test database@host", }, { - dsn: "user/password@host:5432/database", - url: "user@host:5432/database", + dsn: "user/password@host:5432/database", + url: "user@host:5432/database", + name: "test database@host", }, { - dsn: "user=user database=database", - str: "user=user database=database", + dsn: "user=user database=database", + str: "user=user database=database", + name: "test database", }, { - dsn: "user=user password=password database=database", - str: "user=user database=database", + dsn: "user=user password=password database=database", + str: "user=user database=database", + name: "test database", }, { - dsn: "odbc:server=localhost;user id=sa;password={foo}};bar};otherthing=thing", - str: "odbc:server=localhost;user id=sa;otherthing=thing", + dsn: "odbc:server=localhost;user id=sa;password={foo}};bar};otherthing=thing", + str: "odbc:server=localhost;user id=sa;otherthing=thing", + name: "test database", }, } @@ -142,6 +153,7 @@ func TestDSN(t *testing.T) { assert.Equal(t, tt.str, subseg.SQL.ConnectionString) assert.Equal(t, "test version", subseg.SQL.DatabaseVersion) assert.Equal(t, "test user", subseg.SQL.User) + assert.Equal(t, tt.name, subseg.Name) assert.False(t, subseg.Throttle) assert.False(t, subseg.Error) assert.False(t, subseg.Fault)