Skip to content

Commit

Permalink
sqlserver: update test DSN string to be more flexible
Browse files Browse the repository at this point in the history
Support URL for test DSN, allows setting properties (such as
encryption).

Also fix URL query parameter parsing by forcing keys to lower case
that are expected.
  • Loading branch information
kardianos committed May 23, 2017
1 parent e3bd523 commit c2a55c8
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 68 deletions.
5 changes: 4 additions & 1 deletion bulkcopy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ import (
)

func TestBulkcopy(t *testing.T) {

// TDS level Bulk Insert is not supported on Azure SQL Server.
if dsn := makeConnStr(t); strings.HasSuffix(strings.Split(dsn.Host, ":")[0], ".database.windows.net") {
t.Skip("TDS level bulk copy is not supported on Azure SQL Server")
}
type testValue struct {
colname string
val interface{}
Expand Down
8 changes: 4 additions & 4 deletions queries_go18_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ func TestPinger(t *testing.T) {
func TestQueryCancelLowLevel(t *testing.T) {
checkConnStr(t)
drv := driverWithProcess(t)
conn, err := drv.open(makeConnStr())
conn, err := drv.open(makeConnStr(t).String())
if err != nil {
t.Fatalf("Open failed with error %v", err)
}
Expand Down Expand Up @@ -547,7 +547,7 @@ func TestDriverParams(t *testing.T) {

for cmdIndex, cmd := range list {
t.Run(cmd.Name, func(t *testing.T) {
db, err := sql.Open(cmd.Driver, makeConnStr())
db, err := sql.Open(cmd.Driver, makeConnStr(t).String())
if err != nil {
t.Fatalf("failed to open driver %q", cmd.Driver)
}
Expand Down Expand Up @@ -700,7 +700,7 @@ func TestDisconnect1(t *testing.T) {
}()
return di
}
db, err := sql.Open("sqlserver", makeConnStr())
db, err := sql.Open("sqlserver", makeConnStr(t).String())
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -762,7 +762,7 @@ func TestDisconnect2(t *testing.T) {
}()
return di
}
db, err := sql.Open("sqlserver", makeConnStr())
db, err := sql.Open("sqlserver", makeConnStr(t).String())
if err != nil {
t.Fatal(err)
}
Expand Down
58 changes: 17 additions & 41 deletions queries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,11 @@ func TestShortTimeout(t *testing.T) {
}
checkConnStr(t)
SetLogger(testLogger{t})
dsn := makeConnStr() + ";Connection Timeout=2"
conn, err := sql.Open("mssql", dsn)
dsn := makeConnStr(t)
dsnParams := dsn.Query()
dsnParams.Set("Connection Timeout", "2")
dsn.RawQuery = dsnParams.Encode()
conn, err := sql.Open("mssql", dsn.String())
if err != nil {
t.Fatal("Open connection failed:", err.Error())
}
Expand Down Expand Up @@ -741,33 +744,6 @@ func TestBug32(t *testing.T) {
}
}

/*
func TestLogging(t *testing.T) {
flags := log.Flags()
defer func() {
log.SetFlags(flags)
log.SetOutput(os.Stderr)
}()
log.SetFlags(0)
var b bytes.Buffer
log.SetOutput(&b)
dsn := makeConnStr() + ";Log=2"
conn, err := sql.Open("mssql", dsn)
if err != nil {
t.Fatal("Open connection failed:", err.Error())
}
defer conn.Close()
_, err = conn.Exec("print 'test'")
if err != nil {
t.Fatal("Exec print failed", err.Error())
}
if b.String() != "test\n" {
t.Fatal("logging test failed, got", b.String())
}
}
*/

func TestIgnoreEmptyResults(t *testing.T) {
conn := open(t)
defer conn.Close()
Expand All @@ -791,7 +767,7 @@ func TestIgnoreEmptyResults(t *testing.T) {
func TestMssqlStmt_SetQueryNotification(t *testing.T) {
checkConnStr(t)
mssqldriver := driverWithProcess(t)
cn, err := mssqldriver.Open(makeConnStr())
cn, err := mssqldriver.Open(makeConnStr(t).String())
stmt, err := cn.Prepare("SELECT 1")
if err != nil {
t.Error("Connection failed", err)
Expand Down Expand Up @@ -866,7 +842,7 @@ func TestConnectionClosing(t *testing.T) {
func TestBeginTranError(t *testing.T) {
checkConnStr(t)
drv := driverWithProcess(t)
conn, err := drv.open(makeConnStr())
conn, err := drv.open(makeConnStr(t).String())
if err != nil {
t.Fatalf("Open failed with error %v", err)
}
Expand All @@ -882,7 +858,7 @@ func TestBeginTranError(t *testing.T) {
}

// reopen connection
conn, err = drv.open(makeConnStr())
conn, err = drv.open(makeConnStr(t).String())
if err != nil {
t.Fatalf("Open failed with error %v", err)
}
Expand All @@ -905,7 +881,7 @@ func TestBeginTranError(t *testing.T) {
func TestCommitTranError(t *testing.T) {
checkConnStr(t)
drv := driverWithProcess(t)
conn, err := drv.open(makeConnStr())
conn, err := drv.open(makeConnStr(t).String())
if err != nil {
t.Fatalf("Open failed with error %v", err)
}
Expand All @@ -921,7 +897,7 @@ func TestCommitTranError(t *testing.T) {
}

// reopen connection
conn, err = drv.open(makeConnStr())
conn, err = drv.open(makeConnStr(t).String())
if err != nil {
t.Fatalf("Open failed with error %v", err)
}
Expand All @@ -941,7 +917,7 @@ func TestCommitTranError(t *testing.T) {
}

// reopen connection
conn, err = drv.open(makeConnStr())
conn, err = drv.open(makeConnStr(t).String())
defer conn.Close()
if err != nil {
t.Fatalf("Open failed with error %v", err)
Expand All @@ -959,7 +935,7 @@ func TestCommitTranError(t *testing.T) {
func TestRollbackTranError(t *testing.T) {
checkConnStr(t)
drv := driverWithProcess(t)
conn, err := drv.open(makeConnStr())
conn, err := drv.open(makeConnStr(t).String())
if err != nil {
t.Fatalf("Open failed with error %v", err)
}
Expand All @@ -975,7 +951,7 @@ func TestRollbackTranError(t *testing.T) {
}

// reopen connection
conn, err = drv.open(makeConnStr())
conn, err = drv.open(makeConnStr(t).String())
if err != nil {
t.Fatalf("Open failed with error %v", err)
}
Expand All @@ -995,7 +971,7 @@ func TestRollbackTranError(t *testing.T) {
}

// reopen connection
conn, err = drv.open(makeConnStr())
conn, err = drv.open(makeConnStr(t).String())
defer conn.Close()
if err != nil {
t.Fatalf("Open failed with error %v", err)
Expand All @@ -1013,7 +989,7 @@ func TestRollbackTranError(t *testing.T) {
func TestSendQueryErrors(t *testing.T) {
checkConnStr(t)
drv := driverWithProcess(t)
conn, err := drv.open(makeConnStr())
conn, err := drv.open(makeConnStr(t).String())
if err != nil {
t.FailNow()
}
Expand Down Expand Up @@ -1053,7 +1029,7 @@ func TestSendQueryErrors(t *testing.T) {
func TestProcessQueryErrors(t *testing.T) {
checkConnStr(t)
drv := driverWithProcess(t)
conn, err := drv.open(makeConnStr())
conn, err := drv.open(makeConnStr(t).String())
if err != nil {
t.Fatal("open expected to succeed, but it failed with", err)
}
Expand All @@ -1080,7 +1056,7 @@ func TestProcessQueryErrors(t *testing.T) {
func TestSendExecErrors(t *testing.T) {
checkConnStr(t)
drv := driverWithProcess(t)
conn, err := drv.open(makeConnStr())
conn, err := drv.open(makeConnStr(t).String())
if err != nil {
t.FailNow()
}
Expand Down
2 changes: 1 addition & 1 deletion tds.go
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,7 @@ func splitConnectionStringURL(dsn string) (map[string]string, error) {
if len(v) > 1 {
return res, fmt.Errorf("key %s provided more than once", k)
}
res[k] = v[0]
res[strings.ToLower(k)] = v[0]
}

return res, nil
Expand Down
82 changes: 61 additions & 21 deletions tds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql"
"encoding/hex"
"fmt"
"net/url"
"os"
"testing"
"time"
Expand Down Expand Up @@ -72,7 +73,7 @@ func TestSendLogin(t *testing.T) {

func TestSendSqlBatch(t *testing.T) {
checkConnStr(t)
p, err := parseConnectParams(makeConnStr())
p, err := parseConnectParams(makeConnStr(t).String())
if err != nil {
t.Error("parseConnectParams failed:", err.Error())
return
Expand Down Expand Up @@ -127,21 +128,38 @@ loop:
}

func checkConnStr(t *testing.T) {
if len(os.Getenv("SQLSERVER_DSN")) > 0 {
return
}
if len(os.Getenv("HOST")) > 0 && len(os.Getenv("DATABASE")) > 0 {
return
}
t.Skip("no database connection string")
}

func makeConnStr() string {
addr := os.Getenv("HOST")
instance := os.Getenv("INSTANCE")
user := os.Getenv("SQLUSER")
password := os.Getenv("SQLPASSWORD")
database := os.Getenv("DATABASE")
return fmt.Sprintf(
"Server=%s\\%s;User Id=%s;Password=%s;Database=%s;log=127",
addr, instance, user, password, database)
// makeConnStr returns a URL struct so it may be modified by various
// tests before used as a DSN.
func makeConnStr(t *testing.T) *url.URL {
dsn := os.Getenv("SQLSERVER_DSN")
if len(dsn) > 0 {
parsed, err := url.Parse(dsn)
if err != nil {
t.Fatal("unable to parse SQLSERVER_DSN as URL", err)
}
values := parsed.Query()
values.Set("log", "127")
parsed.RawQuery = values.Encode()
return parsed
}
values := url.Values{}
values.Set("log", "127")
values.Set("database", os.Getenv("DATABASE"))
return &url.URL{
Host: os.Getenv("HOST"),
Path: os.Getenv("INSTANCE"),
User: url.UserPassword(os.Getenv("SQLUSER"), os.Getenv("SQLPASSWORD")),
RawQuery: values.Encode(),
}
}

type testLogger struct {
Expand All @@ -159,7 +177,7 @@ func (l testLogger) Println(v ...interface{}) {
func open(t *testing.T) *sql.DB {
checkConnStr(t)
SetLogger(testLogger{t})
conn, err := sql.Open("mssql", makeConnStr())
conn, err := sql.Open("mssql", makeConnStr(t).String())
if err != nil {
t.Error("Open connection failed:", err.Error())
return nil
Expand All @@ -171,7 +189,7 @@ func open(t *testing.T) *sql.DB {
func TestConnect(t *testing.T) {
checkConnStr(t)
SetLogger(testLogger{t})
conn, err := sql.Open("mssql", makeConnStr())
conn, err := sql.Open("mssql", makeConnStr(t).String())
if err != nil {
t.Error("Open connection failed:", err.Error())
return
Expand All @@ -180,13 +198,22 @@ func TestConnect(t *testing.T) {
}

func TestBadConnect(t *testing.T) {
badDsns := []string{
//"Server=badhost",
fmt.Sprintf("Server=%s\\%s;User ID=baduser;Password=badpwd",
os.Getenv("HOST"), os.Getenv("INSTANCE")),
var badDSNs []string

if parsed, err := url.Parse(os.Getenv("SQLSERVER_DSN")); err == nil {
parsed.User = url.UserPassword("baduser", "badpwd")
badDSNs = append(badDSNs, parsed.String())
}
if len(os.Getenv("HOST")) > 0 && len(os.Getenv("INSTANCE")) > 0 {
badDSNs = append(badDSNs,
fmt.Sprintf(
"Server=%s\\%s;User ID=baduser;Password=badpwd",
os.Getenv("HOST"), os.Getenv("INSTANCE"),
),
)
}
SetLogger(testLogger{t})
for _, badDsn := range badDsns {
for _, badDsn := range badDSNs {
conn, err := sql.Open("mssql", badDsn)
if err != nil {
t.Error("Open connection failed:", err.Error())
Expand Down Expand Up @@ -321,8 +348,15 @@ func TestPing(t *testing.T) {
func TestSecureWithInvalidHostName(t *testing.T) {
checkConnStr(t)
SetLogger(testLogger{t})
dsn := makeConnStr() + ";Encrypt=true;TrustServerCertificate=false;hostNameInCertificate=foo.bar"
conn, err := sql.Open("mssql", dsn)

dsn := makeConnStr(t)
dsnParams := dsn.Query()
dsnParams.Set("encrypt", "true")
dsnParams.Set("TrustServerCertificate", "false")
dsnParams.Set("hostNameInCertificate", "foo.bar")
dsn.RawQuery = dsnParams.Encode()

conn, err := sql.Open("mssql", dsn.String())
if err != nil {
t.Fatal("Open connection failed:", err.Error())
}
Expand All @@ -336,8 +370,14 @@ func TestSecureWithInvalidHostName(t *testing.T) {
func TestSecureConnection(t *testing.T) {
checkConnStr(t)
SetLogger(testLogger{t})
dsn := makeConnStr() + ";Encrypt=true;TrustServerCertificate=true"
conn, err := sql.Open("mssql", dsn)

dsn := makeConnStr(t)
dsnParams := dsn.Query()
dsnParams.Set("encrypt", "true")
dsnParams.Set("TrustServerCertificate", "true")
dsn.RawQuery = dsnParams.Encode()

conn, err := sql.Open("mssql", dsn.String())
if err != nil {
t.Fatal("Open connection failed:", err.Error())
}
Expand Down

0 comments on commit c2a55c8

Please sign in to comment.