Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue #453 Add support for custom driver detector #461

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 85 additions & 65 deletions xray/sql_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"time"
)

const detectorDefaultKey = "default"

// we can't know that the original driver will return driver.ErrSkip in advance.
// so we add this message to the query if it returns driver.ErrSkip.
const msgErrSkip = " -- skip fast-path; continue as if unimplemented"
Expand All @@ -40,9 +42,27 @@ type namedValueChecker interface {
var (
muInitializedDrivers sync.Mutex
initializedDrivers map[string]struct{}
attrHook func(attr *dbAttribute) // for testing
attrHook func(attr *DBAttribute) // for testing
registeredDetectors map[string][]Detector
)

func initDetectors() {
RegisterSQLDetector("mysql", mysqlDetector)
RegisterSQLDetector("postgres", postgresDetector)
RegisterSQLDetector(detectorDefaultKey, postgresDetector, mysqlDetector, mssqlDetector, oracleDetector)
}

// RegisterSQLDetector - Register a detector for a specific SQL driver.
func RegisterSQLDetector(
driverName string,
detector ...Detector,
) {
if registeredDetectors == nil {
registeredDetectors = make(map[string][]Detector)
}
registeredDetectors[driverName] = append(registeredDetectors[driverName], detector...)
}

func initXRayDriver(driver, dsn string) error {
muInitializedDrivers.Lock()
defer muInitializedDrivers.Unlock()
Expand All @@ -62,8 +82,9 @@ func initXRayDriver(driver, dsn string) error {
Driver: db.Driver(),
baseName: driver,
})
initializedDrivers[driver] = struct{}{}
db.Close()
initializedDrivers[driver] = struct{}{}
initDetectors()
return nil
}

Expand All @@ -83,6 +104,8 @@ type driverDriver struct {
baseName string // the name of the base driver
}

type Detector func(ctx context.Context, conn driver.Conn, attr *DBAttribute) error

func (d *driverDriver) Open(dsn string) (driver.Conn, error) {
rawConn, err := d.Driver.Open(dsn)
if err != nil {
Expand All @@ -103,11 +126,11 @@ func (d *driverDriver) Open(dsn string) (driver.Conn, error) {

type driverConn struct {
driver.Conn
attr *dbAttribute
attr *DBAttribute
}

func (conn *driverConn) Ping(ctx context.Context) error {
return Capture(ctx, conn.attr.dbname+conn.attr.host, func(ctx context.Context) error {
return Capture(ctx, conn.attr.DBname+conn.attr.Host, func(ctx context.Context) error {
conn.attr.populate(ctx, "PING")
if p, ok := conn.Conn.(driver.Pinger); ok {
return p.Ping(ctx)
Expand Down Expand Up @@ -187,7 +210,7 @@ func (conn *driverConn) ExecContext(ctx context.Context, query string, args []dr
var err error
var result driver.Result
if execerCtx, ok := conn.Conn.(driver.ExecerContext); ok {
Capture(ctx, conn.attr.dbname+conn.attr.host, func(ctx context.Context) error {
Capture(ctx, conn.attr.DBname+conn.attr.Host, func(ctx context.Context) error {
result, err = execerCtx.ExecContext(ctx, query, args)
if err == driver.ErrSkip {
conn.attr.populate(ctx, query+msgErrSkip)
Expand All @@ -210,7 +233,7 @@ func (conn *driverConn) ExecContext(ctx context.Context, query string, args []dr
if err0 != nil {
return nil, err0
}
Capture(ctx, conn.attr.dbname+conn.attr.host, func(ctx context.Context) error {
Capture(ctx, conn.attr.DBname+conn.attr.Host, func(ctx context.Context) error {
var err error
result, err = execer.Exec(query, dargs)
if err == driver.ErrSkip {
Expand All @@ -232,7 +255,7 @@ func (conn *driverConn) QueryContext(ctx context.Context, query string, args []d
var err error
var rows driver.Rows
if queryerCtx, ok := conn.Conn.(driver.QueryerContext); ok {
Capture(ctx, conn.attr.dbname+conn.attr.host, func(ctx context.Context) error {
Capture(ctx, conn.attr.DBname+conn.attr.Host, func(ctx context.Context) error {
rows, err = queryerCtx.QueryContext(ctx, query, args)
if err == driver.ErrSkip {
conn.attr.populate(ctx, query+msgErrSkip)
Expand All @@ -255,7 +278,7 @@ func (conn *driverConn) QueryContext(ctx context.Context, query string, args []d
if err0 != nil {
return nil, err0
}
err = Capture(ctx, conn.attr.dbname+conn.attr.host, func(ctx context.Context) error {
err = Capture(ctx, conn.attr.DBname+conn.attr.Host, func(ctx context.Context) error {
rows, err = queryer.Query(query, dargs)
if err == driver.ErrSkip {
conn.attr.populate(ctx, query+msgErrSkip)
Expand Down Expand Up @@ -292,19 +315,19 @@ func (conn *driverConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
return defaultCheckNamedValue(nv)
}

type dbAttribute struct {
connectionString string
url string
databaseType string
databaseVersion string
driverVersion string
user string
dbname string
host string
type DBAttribute struct {
ConnectionString string
URL string
DatabaseType string
DatabaseVersion string
DriverVersion string
User string
DBname string
Host string
}

func newDBAttribute(ctx context.Context, driverName string, d driver.Driver, conn driver.Conn, dsn string, filtered bool) (*dbAttribute, error) {
var attr dbAttribute
func newDBAttribute(ctx context.Context, driverName string, d driver.Driver, conn driver.Conn, dsn string, filtered bool) (*DBAttribute, error) {
var attr DBAttribute

// Detect if DSN is a URL or not, set appropriate attribute
urlDsn := dsn
Expand Down Expand Up @@ -345,45 +368,42 @@ func newDBAttribute(ctx context.Context, driverName string, d driver.Driver, con
// In the case of known DSL sub segment name will be dbname@host
host, _, _ := net.SplitHostPort(u.Host)
if len(host) > 0 {
attr.host = "@" + host
attr.Host = "@" + host
} else {
attr.host = host
attr.Host = host
}

attr.url = u.String()
attr.URL = u.String()
if !strings.Contains(dsn, "//") {
attr.url = attr.url[2:]
attr.URL = attr.URL[2:]
}
} else {
// We don't *think* it's a URL, so now we have to try our best to strip passwords from
// some unknown DSL. We attempt to detect whether it's space-delimited or semicolon-delimited
// then remove any keys with the name "password" or "pwd". This won't catch everything, but
// from surveying the current (Jan 2017) landscape of drivers it should catch most.
if filtered {
attr.connectionString = dsn
attr.ConnectionString = dsn
} else {
attr.connectionString = stripPasswords(dsn)
attr.ConnectionString = stripPasswords(dsn)
}
}

// Detect database type and use that to populate attributes
var detectors []func(ctx context.Context, conn driver.Conn, attr *dbAttribute) error
switch driverName {
case "postgres":
detectors = append(detectors, postgresDetector)
case "mysql":
detectors = append(detectors, mysqlDetector)
default:
detectors = append(detectors, postgresDetector, mysqlDetector, mssqlDetector, oracleDetector)
}
for _, detector := range detectors {
var driverDetectors []Detector
if v, ok := registeredDetectors[driverName]; ok {
driverDetectors = v
} else {
driverDetectors = registeredDetectors["default"]
}
for _, detector := range driverDetectors {
if detector(ctx, conn, &attr) == nil {
break
}
attr.databaseType = "Unknown"
attr.databaseVersion = "Unknown"
attr.user = "Unknown"
attr.dbname = "Unknown"
attr.DatabaseType = "Unknown"
attr.DatabaseVersion = "Unknown"
attr.User = "Unknown"
attr.DBname = "Unknown"
}

// There's no standard to get SQL driver version information
Expand All @@ -393,13 +413,13 @@ func newDBAttribute(ctx context.Context, driverName string, d driver.Driver, con
}

if vd, ok := d.(versionedDriver); ok {
attr.driverVersion = vd.Version()
attr.DriverVersion = vd.Version()
} else {
t := reflect.TypeOf(d)
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
attr.driverVersion = t.PkgPath()
attr.DriverVersion = t.PkgPath()
}

if attrHook != nil {
Expand All @@ -408,39 +428,39 @@ func newDBAttribute(ctx context.Context, driverName string, d driver.Driver, con
return &attr, nil
}

func postgresDetector(ctx context.Context, conn driver.Conn, attr *dbAttribute) error {
attr.databaseType = "Postgres"
func postgresDetector(ctx context.Context, conn driver.Conn, attr *DBAttribute) error {
attr.DatabaseType = "Postgres"
return queryRow(
ctx, conn,
"SELECT version(), current_user, current_database()",
&attr.databaseVersion, &attr.user, &attr.dbname,
&attr.DatabaseVersion, &attr.User, &attr.DBname,
)
}

func mysqlDetector(ctx context.Context, conn driver.Conn, attr *dbAttribute) error {
attr.databaseType = "MySQL"
func mysqlDetector(ctx context.Context, conn driver.Conn, attr *DBAttribute) error {
attr.DatabaseType = "MySQL"
return queryRow(
ctx, conn,
"SELECT version(), current_user(), database()",
&attr.databaseVersion, &attr.user, &attr.dbname,
&attr.DatabaseVersion, &attr.User, &attr.DBname,
)
}

func mssqlDetector(ctx context.Context, conn driver.Conn, attr *dbAttribute) error {
attr.databaseType = "MS SQL"
func mssqlDetector(ctx context.Context, conn driver.Conn, attr *DBAttribute) error {
attr.DatabaseType = "MS SQL"
return queryRow(
ctx, conn,
"SELECT @@version, current_user, db_name()",
&attr.databaseVersion, &attr.user, &attr.dbname,
&attr.DatabaseVersion, &attr.User, &attr.DBname,
)
}

func oracleDetector(ctx context.Context, conn driver.Conn, attr *dbAttribute) error {
attr.databaseType = "Oracle"
func oracleDetector(ctx context.Context, conn driver.Conn, attr *DBAttribute) error {
attr.DatabaseType = "Oracle"
return queryRow(
ctx, conn,
"SELECT version FROM v$instance UNION SELECT user, ora_database_name FROM dual",
&attr.databaseVersion, &attr.user, &attr.dbname,
&attr.DatabaseVersion, &attr.User, &attr.DBname,
)
}

Expand Down Expand Up @@ -516,7 +536,7 @@ func queryRow(ctx context.Context, conn driver.Conn, query string, dest ...*stri
return nil
}

func (attr *dbAttribute) populate(ctx context.Context, query string) {
func (attr *DBAttribute) populate(ctx context.Context, query string) {
seg := GetSegment(ctx)

if seg == nil {
Expand All @@ -526,12 +546,12 @@ func (attr *dbAttribute) populate(ctx context.Context, query string) {

seg.Lock()
seg.Namespace = "remote"
seg.GetSQL().ConnectionString = attr.connectionString
seg.GetSQL().URL = attr.url
seg.GetSQL().DatabaseType = attr.databaseType
seg.GetSQL().DatabaseVersion = attr.databaseVersion
seg.GetSQL().DriverVersion = attr.driverVersion
seg.GetSQL().User = attr.user
seg.GetSQL().ConnectionString = attr.ConnectionString
seg.GetSQL().URL = attr.URL
seg.GetSQL().DatabaseType = attr.DatabaseType
seg.GetSQL().DatabaseVersion = attr.DatabaseVersion
seg.GetSQL().DriverVersion = attr.DriverVersion
seg.GetSQL().User = attr.User
seg.GetSQL().SanitizedQuery = query
seg.Unlock()
}
Expand All @@ -551,7 +571,7 @@ func (tx *driverTx) Rollback() error {
type driverStmt struct {
driver.Stmt
conn *driverConn
attr *dbAttribute
attr *DBAttribute
query string
}

Expand All @@ -571,7 +591,7 @@ func (stmt *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValu
var result driver.Result
var err error
if execerContext, ok := stmt.Stmt.(driver.StmtExecContext); ok {
err = Capture(ctx, stmt.attr.dbname+stmt.attr.host, func(ctx context.Context) error {
err = Capture(ctx, stmt.attr.DBname+stmt.attr.Host, func(ctx context.Context) error {
stmt.populate(ctx)
var err error
result, err = execerContext.ExecContext(ctx, args)
Expand All @@ -587,7 +607,7 @@ func (stmt *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValu
if err0 != nil {
return nil, err0
}
err = Capture(ctx, stmt.attr.dbname+stmt.attr.host, func(ctx context.Context) error {
err = Capture(ctx, stmt.attr.DBname+stmt.attr.Host, func(ctx context.Context) error {
stmt.populate(ctx)
var err error
result, err = stmt.Stmt.Exec(dargs)
Expand All @@ -608,7 +628,7 @@ func (stmt *driverStmt) QueryContext(ctx context.Context, args []driver.NamedVal
var result driver.Rows
var err error
if queryCtx, ok := stmt.Stmt.(driver.StmtQueryContext); ok {
err = Capture(ctx, stmt.attr.dbname+stmt.attr.host, func(ctx context.Context) error {
err = Capture(ctx, stmt.attr.DBname+stmt.attr.Host, func(ctx context.Context) error {
stmt.populate(ctx)
var err error
result, err = queryCtx.QueryContext(ctx, args)
Expand All @@ -624,7 +644,7 @@ func (stmt *driverStmt) QueryContext(ctx context.Context, args []driver.NamedVal
if err0 != nil {
return nil, err0
}
err = Capture(ctx, stmt.attr.dbname+stmt.attr.host, func(ctx context.Context) error {
err = Capture(ctx, stmt.attr.DBname+stmt.attr.Host, func(ctx context.Context) error {
stmt.populate(ctx)
var err error
result, err = stmt.Stmt.Query(dargs)
Expand Down
7 changes: 4 additions & 3 deletions xray/sql_go110.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//
// or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

//go:build go1.10
// +build go1.10

package xray
Expand Down Expand Up @@ -47,7 +48,7 @@ type driverConnector struct {
name string

mu sync.RWMutex
attr *dbAttribute
attr *DBAttribute
}

func (c *driverConnector) Connect(ctx context.Context) (driver.Conn, error) {
Expand All @@ -56,7 +57,7 @@ func (c *driverConnector) Connect(ctx context.Context) (driver.Conn, error) {
if err != nil {
return nil, err
}
err = Capture(ctx, attr.dbname+attr.host, func(ctx context.Context) error {
err = Capture(ctx, attr.DBname+attr.Host, func(ctx context.Context) error {
attr.populate(ctx, "CONNECT")
var err error
rawConn, err = c.Connector.Connect(ctx)
Expand All @@ -73,7 +74,7 @@ func (c *driverConnector) Connect(ctx context.Context) (driver.Conn, error) {
return conn, nil
}

func (c *driverConnector) getAttr(ctx context.Context) (*dbAttribute, error) {
func (c *driverConnector) getAttr(ctx context.Context) (*DBAttribute, error) {
c.mu.RLock()
attr := c.attr
c.mu.RUnlock()
Expand Down
Loading