From 34d7302555700bcaa3343ee597e088c0bb9c9903 Mon Sep 17 00:00:00 2001 From: "Rose, William" Date: Fri, 18 Dec 2020 07:51:06 -0800 Subject: [PATCH 1/6] Export the token connector creators --- fedauth.go | 4 ++-- tds_login_test.go | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/fedauth.go b/fedauth.go index 86fed253..a70117ff 100644 --- a/fedauth.go +++ b/fedauth.go @@ -39,7 +39,7 @@ const ( // service specified and obtain the appropriate token, or return an error // to indicate why a token is not available. // The returned connector may be used with sql.OpenDB. -func newSecurityTokenConnector(dsn string, tokenProvider func(ctx context.Context) (string, error)) (*Connector, error) { +func NewSecurityTokenConnector(dsn string, tokenProvider func(ctx context.Context) (string, error)) (*Connector, error) { if tokenProvider == nil { return nil, errors.New("mssql: tokenProvider cannot be nil") } @@ -64,7 +64,7 @@ func newSecurityTokenConnector(dsn string, tokenProvider func(ctx context.Contex // to indicate why a token is not available. // // The returned connector may be used with sql.OpenDB. -func newActiveDirectoryTokenConnector(dsn string, adalWorkflow byte, tokenProvider func(ctx context.Context, serverSPN, stsURL string) (string, error)) (*Connector, error) { +func NewActiveDirectoryTokenConnector(dsn string, adalWorkflow byte, tokenProvider func(ctx context.Context, serverSPN, stsURL string) (string, error)) (*Connector, error) { if tokenProvider == nil { return nil, errors.New("mssql: tokenProvider cannot be nil") } diff --git a/tds_login_test.go b/tds_login_test.go index 9f18fc0d..d6b4cdc6 100644 --- a/tds_login_test.go +++ b/tds_login_test.go @@ -155,7 +155,7 @@ func TestLoginWithSQLServerAuth(t *testing.T) { } func TestLoginWithSecurityTokenAuth(t *testing.T) { - conn, err := newSecurityTokenConnector("sqlserver://localhost:1433?Workstation ID=localhost&log=128", + conn, err := NewSecurityTokenConnector("sqlserver://localhost:1433?Workstation ID=localhost&log=128", func(ctx context.Context) (string, error) { return "", nil }, @@ -207,7 +207,7 @@ func TestLoginWithSecurityTokenAuth(t *testing.T) { } func TestLoginWithADALUsernamePasswordAuth(t *testing.T) { - conn, err := newActiveDirectoryTokenConnector( + conn, err := NewActiveDirectoryTokenConnector( "sqlserver://localhost:1433?Workstation ID=localhost&log=128", fedAuthADALWorkflowPassword, func(ctx context.Context, serverSPN, stsURL string) (string, error) { @@ -272,7 +272,7 @@ func TestLoginWithADALUsernamePasswordAuth(t *testing.T) { } func TestLoginWithADALManagedIdentityAuth(t *testing.T) { - conn, err := newActiveDirectoryTokenConnector( + conn, err := NewActiveDirectoryTokenConnector( "sqlserver://localhost:1433?Workstation ID=localhost&log=128", fedAuthADALWorkflowMSI, func(ctx context.Context, serverSPN, stsURL string) (string, error) { From 69e492e5fac1fadd220361334ab9d1aae3b17b32 Mon Sep 17 00:00:00 2001 From: "Rose, William" Date: Wed, 16 Dec 2020 11:06:53 -0800 Subject: [PATCH 2/6] Add connection logging to help with debugging --- conn_str.go | 7 +++ conn_str_test.go | 5 +- log_conn.go | 80 +++++++++++++++++++++++++++++++ log_conn_test.go | 121 +++++++++++++++++++++++++++++++++++++++++++++++ tds.go | 20 +++++++- 5 files changed, 230 insertions(+), 3 deletions(-) create mode 100644 log_conn.go create mode 100644 log_conn_test.go diff --git a/conn_str.go b/conn_str.go index d7d9e06a..decb7632 100644 --- a/conn_str.go +++ b/conn_str.go @@ -1,6 +1,7 @@ package mssql import ( + "errors" "fmt" "net" "net/url" @@ -39,6 +40,7 @@ type connectParams struct { packetSize uint16 fedAuthLibrary int fedAuthADALWorkflow byte + tlsKeyLogFile string } // default packet size for TDS buffer @@ -235,6 +237,11 @@ func parseConnectParams(dsn string) (connectParams, error) { } } + p.tlsKeyLogFile, ok = params["tls key log file"] + if ok && p.tlsKeyLogFile != "" && p.disableEncryption { + return p, errors.New("Cannot set tlsKeyLogFile when encryption is disabled") + } + return p, nil } diff --git a/conn_str_test.go b/conn_str_test.go index bb6e2682..a7f953a3 100644 --- a/conn_str_test.go +++ b/conn_str_test.go @@ -67,6 +67,7 @@ func TestValidConnectionString(t *testing.T) { {"trustservercertificate=false", func(p connectParams) bool { return !p.trustServerCertificate }}, {"certificate=abc", func(p connectParams) bool { return p.certificate == "abc" }}, {"hostnameincertificate=abc", func(p connectParams) bool { return p.hostInCertificate == "abc" }}, + {"tls key log file=tls.log", func(p connectParams) bool { return p.tlsKeyLogFile == "tls.log" }}, {"connection timeout=3;dial timeout=4;keepalive=5", func(p connectParams) bool { return p.conn_timeout == 3*time.Second && p.dial_timeout == 4*time.Second && p.keepAlive == 5*time.Second }}, @@ -186,10 +187,10 @@ func testConnParams(t testing.TB) connectParams { } if len(os.Getenv("HOST")) > 0 && len(os.Getenv("DATABASE")) > 0 { return connectParams{ - host: os.Getenv("HOST"), + host: os.Getenv("HOST"), instance: os.Getenv("INSTANCE"), database: os.Getenv("DATABASE"), - user: os.Getenv("SQLUSER"), + user: os.Getenv("SQLUSER"), password: os.Getenv("SQLPASSWORD"), logFlags: logFlags, } diff --git a/log_conn.go b/log_conn.go new file mode 100644 index 00000000..4777e4c1 --- /dev/null +++ b/log_conn.go @@ -0,0 +1,80 @@ +package mssql + +import ( + "encoding/hex" + "net" + "strings" + "time" +) + +type connLogger struct { + conn net.Conn + readKind, writeKind string + readCount, writeCount int + logger Logger +} + +var _ net.Conn = &connLogger{} + +func newConnLogger(conn net.Conn, kind string, logger Logger) net.Conn { + if len(kind) > 0 && !strings.HasPrefix(kind, " ") { + kind = " " + kind + } + + cl := &connLogger{ + conn: conn, + readKind: "R" + kind, + writeKind: "W" + kind, + logger: logger, + } + + return cl +} + +func (cl *connLogger) Read(p []byte) (n int, err error) { + n, err = cl.conn.Read(p) + + if n > 0 { + dump := hex.Dump(p) + cl.logger.Printf("%s %d\n%s", cl.readKind, cl.readCount, dump) + cl.readCount += n + } + + return +} + +func (cl *connLogger) Write(p []byte) (n int, err error) { + n, err = cl.conn.Write(p) + + if n > 0 { + dump := hex.Dump(p) + cl.logger.Printf("%s %d\n%s", cl.writeKind, cl.writeCount, dump) + cl.writeCount += n + } + + return +} + +func (cl *connLogger) Close() (err error) { + return cl.conn.Close() +} + +func (cl *connLogger) LocalAddr() net.Addr { + return cl.conn.LocalAddr() +} + +func (cl *connLogger) RemoteAddr() net.Addr { + return cl.conn.RemoteAddr() +} + +func (cl *connLogger) SetDeadline(t time.Time) error { + return cl.conn.SetDeadline(t) +} + +func (cl *connLogger) SetReadDeadline(t time.Time) error { + return cl.conn.SetReadDeadline(t) +} + +func (cl *connLogger) SetWriteDeadline(t time.Time) error { + return cl.conn.SetWriteDeadline(t) +} diff --git a/log_conn_test.go b/log_conn_test.go new file mode 100644 index 00000000..2e4b91d5 --- /dev/null +++ b/log_conn_test.go @@ -0,0 +1,121 @@ +package mssql + +import ( + "net" + "sync/atomic" + "testing" + "time" +) + +func TestConnLoggerOperations(t *testing.T) { + clt := &connLoggerTest{} + cl := newConnLogger(clt, "test", nullLogger{}) + packet := append(make([]byte, 0, 10), 1, 2, 3, 4, 5) + n, err := cl.Read(packet) + if n != 10 || err != nil { + t.Error("Unexpected return value from call to Read()") + } + + n, err = cl.Write(packet) + if n != 5 || err != nil { + t.Error("Unexpected return value from call to Write()") + } + + if cl.Close() != nil { + t.Error("Unexpected return value from call to Close()") + } + + if cl.LocalAddr() == nil { + t.Error("Unexpected return value from call to LocalAddr()") + } + + if cl.RemoteAddr() == nil { + t.Error("Unexpected return value from call to RemoteAddr()") + } + + if cl.SetDeadline(time.Now()) != nil { + t.Error("Unexpected return value from call to SetDeadline()") + } + + if cl.SetReadDeadline(time.Now()) != nil { + t.Error("Unexpected return value from call to SetReadDeadline()") + } + + if cl.SetWriteDeadline(time.Now()) != nil { + t.Error("Unexpected return value from call to SetWriteDeadline()") + } + + if atomic.LoadInt32(&clt.calls) != 8 { + t.Error("Unexpected number of calls recorded") + } +} + +type connLoggerTest struct { + calls int32 +} + +var _ net.Conn = &connLoggerTest{} + +type addressTest struct { +} + +var _ net.Addr = &addressTest{} + +type nullLogger struct { +} + +var _ Logger = nullLogger{} + +func (n nullLogger) Printf(format string, v ...interface{}) { +} + +func (n nullLogger) Println(v ...interface{}) { +} + +func (a *addressTest) Network() string { + return "test" +} + +func (a *addressTest) String() string { + return "test" +} + +func (cl *connLoggerTest) Read(p []byte) (int, error) { + atomic.AddInt32(&cl.calls, 1) + return cap(p), nil +} + +func (cl *connLoggerTest) Write(p []byte) (int, error) { + atomic.AddInt32(&cl.calls, 1) + return len(p), nil +} + +func (cl *connLoggerTest) Close() error { + atomic.AddInt32(&cl.calls, 1) + return nil +} + +func (cl *connLoggerTest) LocalAddr() net.Addr { + atomic.AddInt32(&cl.calls, 1) + return &addressTest{} +} + +func (cl *connLoggerTest) RemoteAddr() net.Addr { + atomic.AddInt32(&cl.calls, 1) + return &addressTest{} +} + +func (cl *connLoggerTest) SetDeadline(t time.Time) error { + atomic.AddInt32(&cl.calls, 1) + return nil +} + +func (cl *connLoggerTest) SetReadDeadline(t time.Time) error { + atomic.AddInt32(&cl.calls, 1) + return nil +} + +func (cl *connLoggerTest) SetWriteDeadline(t time.Time) error { + atomic.AddInt32(&cl.calls, 1) + return nil +} diff --git a/tds.go b/tds.go index e1b63300..a41a9ddf 100644 --- a/tds.go +++ b/tds.go @@ -10,6 +10,7 @@ import ( "io" "io/ioutil" "net" + "os" "sort" "strconv" "strings" @@ -152,6 +153,7 @@ const ( logParams = 16 logTransaction = 32 logDebug = 64 + logTraffic = 128 ) type columnStruct struct { @@ -1059,6 +1061,10 @@ initiate_connection: return nil, err } + if p.logFlags&logTraffic != 0 { + conn = newConnLogger(conn, "TCP", log) + } + toconn := newTimeoutConn(conn, p.conn_timeout) outbuf := newTdsBuffer(p.packetSize, toconn) @@ -1104,6 +1110,14 @@ initiate_connection: if p.trustServerCertificate { config.InsecureSkipVerify = true } + if p.tlsKeyLogFile != "" { + if w, err := os.OpenFile(p.tlsKeyLogFile, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0600); err == nil { + defer w.Close() + config.KeyLogWriter = w + } else { + return nil, fmt.Errorf("Cannot open TLS key log file %s: %v", p.tlsKeyLogFile, err) + } + } config.ServerName = p.hostInCertificate // fix for https://github.com/denisenkom/go-mssqldb/issues/166 // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments, @@ -1116,7 +1130,11 @@ initiate_connection: tlsConn := tls.Client(&passthrough, &config) err = tlsConn.Handshake() passthrough.c = toconn - outbuf.transport = tlsConn + if sess.logFlags&logTraffic != 0 { + outbuf.transport = newConnLogger(tlsConn, "TLS", log) + } else { + outbuf.transport = tlsConn + } if err != nil { return nil, fmt.Errorf("TLS Handshake failed: %v", err) } From 2399f45ee46049ce8caf82ac30c432cfd64ec6ee Mon Sep 17 00:00:00 2001 From: "Rose, William" Date: Wed, 16 Dec 2020 19:31:53 -0800 Subject: [PATCH 3/6] Implement Azure AD token provider. --- azuread/adal_tokens.go | 135 +++++++++++++++ azuread/adal_tokens_test.go | 203 ++++++++++++++++++++++ azuread/configuration.go | 212 +++++++++++++++++++++++ azuread/configuration_test.go | 317 ++++++++++++++++++++++++++++++++++ azuread/conn_str.go | 55 ++++++ azuread/driver.go | 65 +++++++ go.mod | 3 +- go.sum | 20 ++- 8 files changed, 1007 insertions(+), 3 deletions(-) create mode 100644 azuread/adal_tokens.go create mode 100644 azuread/adal_tokens_test.go create mode 100644 azuread/configuration.go create mode 100644 azuread/configuration_test.go create mode 100644 azuread/conn_str.go create mode 100644 azuread/driver.go diff --git a/azuread/adal_tokens.go b/azuread/adal_tokens.go new file mode 100644 index 00000000..9f455323 --- /dev/null +++ b/azuread/adal_tokens.go @@ -0,0 +1,135 @@ +package azuread + +import ( + "context" + "crypto/rsa" + "crypto/x509" + "fmt" + "os" + + "github.com/Azure/go-autorest/autorest/adal" +) + +// When the security token library is used, the token is obtained without input +// from the server, so the AD endpoint and Azure SQL resource URI are provided +// from the constants below. +var ( + // activeDirectoryEndpoint is the security token service URL to use when + // the server does not provide the URL. + activeDirectoryEndpoint = "https://login.microsoftonline.com/" +) + +func init() { + endpoint := os.Getenv("AZURE_AD_STS_URL") + if endpoint != "" { + activeDirectoryEndpoint = endpoint + } +} + +const ( + // azureSQLResource is the AD resource to use when the server does not + // provide the resource. + azureSQLResource = "https://database.windows.net/" + + // driverClientID is the AD client ID to use when performing a username + // and password login. + driverClientID = "7f98cb04-cd1e-40df-9140-3bf7e2cea4db" +) + +func retrieveToken(ctx context.Context, token *adal.ServicePrincipalToken) (string, error) { + err := token.RefreshWithContext(ctx) + if err != nil { + err = fmt.Errorf("Failed to refresh token: %v", err) + return "", err + } + + return token.Token().AccessToken, nil +} + +// SecurityTokenFromCertificate obtains a security token using a certificate and RSA private key. +func SecurityTokenFromCertificate(ctx context.Context, clientID, tenantID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey) (string, error) { + // The activeDirectoryEndpoint URL is used as a base against which the + // tenant ID is resolved. + oauthConfig, err := adal.NewOAuthConfig(activeDirectoryEndpoint, tenantID) + if err != nil { + err = fmt.Errorf("Failed to obtain OAuth configuration for endpoint %s and tenant %s: %v", + activeDirectoryEndpoint, tenantID, err) + return "", err + } + + token, err := adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, clientID, certificate, privateKey, azureSQLResource) + if err != nil { + err = fmt.Errorf("Failed to obtain service principal token for client id %s in tenant %s: %v", clientID, tenantID, err) + return "", err + } + + return retrieveToken(ctx, token) +} + +// SecurityTokenFromSecret obtains a security token using a client ID and secret. +func SecurityTokenFromSecret(ctx context.Context, clientID, tenantID, clientSecret string) (string, error) { + // The activeDirectoryEndpoint URL is used as a base against which the + // tenant ID is resolved. + oauthConfig, err := adal.NewOAuthConfig(activeDirectoryEndpoint, tenantID) + if err != nil { + err = fmt.Errorf("Failed to obtain OAuth configuration for endpoint %s and tenant %s: %v", + activeDirectoryEndpoint, tenantID, err) + return "", err + } + + token, err := adal.NewServicePrincipalToken(*oauthConfig, clientID, clientSecret, azureSQLResource) + + if err != nil { + err = fmt.Errorf("Failed to obtain service principal token for client id %s in tenant %s: %v", clientID, tenantID, err) + return "", err + } + + return retrieveToken(ctx, token) +} + +// ActiveDirectoryTokenFromPassword obtains a security token using an Active Directory username and password. +func ActiveDirectoryTokenFromPassword(ctx context.Context, serverSPN, stsURL, user, password string) (string, error) { + // The activeDirectoryEndpoint URL is used as a base against which the + // STS URL is resolved. However, the STS URL is normally absolute and + // the activeDirectoryEndpoint URL is completely ignored. + oauthConfig, err := adal.NewOAuthConfig(activeDirectoryEndpoint, stsURL) + if err != nil { + err = fmt.Errorf("Failed to obtain OAuth configuration for endpoint %s and tenant %s: %v", + activeDirectoryEndpoint, stsURL, err) + return "", err + } + + token, err := adal.NewServicePrincipalTokenFromUsernamePassword(*oauthConfig, driverClientID, user, password, serverSPN) + + if err != nil { + err = fmt.Errorf("Failed to obtain token for user %s for resource %s from service %s: %v", user, serverSPN, stsURL, err) + return "", err + } + + return retrieveToken(ctx, token) +} + +// ActiveDirectoryTokenFromIdentity obtains a security token the managed identity service. +func ActiveDirectoryTokenFromIdentity(ctx context.Context, serverSPN, stsURL, clientID string) (string, error) { + msiEndpoint, err := adal.GetMSIEndpoint() + if err != nil { + return "", err + } + + var token *adal.ServicePrincipalToken + var access string + if clientID == "" { + access = "system identity" + token, err = adal.NewServicePrincipalTokenFromMSI(msiEndpoint, serverSPN) + } else { + access = "user-assigned identity " + clientID + token, err = adal.NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, serverSPN, clientID) + } + + if err != nil { + err = fmt.Errorf("Failed to obtain token for %s for resource %s from service %s: %v", access, serverSPN, stsURL, err) + return "", err + } + + return retrieveToken(ctx, token) +} diff --git a/azuread/adal_tokens_test.go b/azuread/adal_tokens_test.go new file mode 100644 index 00000000..ddc90f22 --- /dev/null +++ b/azuread/adal_tokens_test.go @@ -0,0 +1,203 @@ +package azuread + +import ( + "context" + "database/sql" + "net/url" + "os" + "strings" + "testing" + + mssql "github.com/denisenkom/go-mssqldb" +) + +type testLogger struct { + t *testing.T +} + +func (l testLogger) Printf(format string, v ...interface{}) { + l.t.Logf(format, v...) +} + +func (l testLogger) Println(v ...interface{}) { + l.t.Log(v...) +} + +func checkAzureSQLEnvironment(fedAuth string, t *testing.T) (*url.URL, string) { + u := &url.URL{ + Scheme: "sqlserver", + Host: os.Getenv("SQL_SERVER"), + } + + if u.Host == "" { + t.Skip("Azure SQL Server name not provided in SQL_SERVER environment variable") + } + + database := os.Getenv("SQL_DATABASE") + if database == "" { + t.Skip("Azure SQL database name not provided in SQL_DATABASE environment variable") + } + + tenantID := os.Getenv("AZURE_TENANT_ID") + if tenantID == "" { + t.Skip("Azure tenant ID not provided in AZURE_TENANT_ID environment variable") + } + + query := u.Query() + + query.Add("database", database) + query.Add("encrypt", "true") + query.Add("fedauth", fedAuth) + + u.RawQuery = query.Encode() + + return u, tenantID +} + +func checkFedAuthUserPassword(t *testing.T) *url.URL { + u, _ := checkAzureSQLEnvironment("ActiveDirectoryPassword", t) + + username := os.Getenv("SQL_AD_ADMIN_USER") + password := os.Getenv("SQL_AD_ADMIN_PASSWORD") + + if username == "" || password == "" { + t.Skip("Username and password login requires SQL_AD_ADMIN_USER and SQL_AD_ADMIN_PASSWORD environment variables") + } + + u.User = url.UserPassword(username, password) + + return u +} + +func checkFedAuthAppPassword(t *testing.T) *url.URL { + u, tenantID := checkAzureSQLEnvironment("ActiveDirectoryApplication", t) + + appClientID := os.Getenv("APP_SP_CLIENT_ID") + appPassword := os.Getenv("APP_SP_CLIENT_SECRET") + + if appClientID == "" || appPassword == "" { + t.Skip("Application (service principal) login requires APP_SP_CLIENT_ID and APP_SP_CLIENT_SECRET environment variables") + } + + u.User = url.UserPassword(appClientID+"@"+tenantID, appPassword) + + return u +} + +func checkFedAuthAppCertPath(t *testing.T) *url.URL { + u := checkFedAuthAppPassword(t) + + appCertPath := os.Getenv("APP_SP_CLIENT_CERT") + if appCertPath == "" { + t.Skip("Application (service principal) certificate login requires APP_SP_CLIENT_CERT with path to certificate") + } + + query := u.Query() + query.Add("clientcertpath", appCertPath) + u.RawQuery = query.Encode() + + return u +} + +func checkFedAuthVMSystemID(t *testing.T) (*url.URL, string) { + u, tenantID := checkAzureSQLEnvironment("ActiveDirectoryMSI", t) + + vmClientID := os.Getenv("VM_CLIENT_ID") + if vmClientID == "" { + t.Skip("System-assigned identity login test requires VM_CLIENT_ID environment variable") + } + + return u, vmClientID + "@" + tenantID +} + +func checkFedAuthVMUserAssignedID(t *testing.T) (*url.URL, string) { + u, tenantID := checkAzureSQLEnvironment("ActiveDirectoryMSI", t) + + uaClientID := os.Getenv("UA_CLIENT_ID") + if uaClientID == "" { + t.Skip("User-assigned identity login test requires UA_CLIENT_ID environment variable") + } + + u.User = url.User(uaClientID) + + return u, uaClientID + "@" + tenantID +} + +func checkLoggedInUser(expected string, u *url.URL, t *testing.T) { + db, err := sql.Open(DriverName, u.String()) + if err != nil { + t.Fatalf("Failed to open URL %v: %v", u, err) + } + + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sql := "SELECT SUSER_NAME()" + + stmt, err := db.PrepareContext(ctx, sql) + if err != nil { + t.Fatalf("Failed to prepare query %s: %v", sql, err) + } + + defer stmt.Close() + + rows, err := stmt.QueryContext(ctx) + if err != nil { + t.Fatalf("Failed to fetch query result for %s: %v", sql, err) + } + + defer rows.Close() + + var username string + if !rows.Next() { + t.Fatalf("Empty result set for query %s", sql) + } + + err = rows.Scan(&username) + if err != nil { + t.Fatalf("Failed to fetch first row for %s: %v", sql, err) + } + + if !strings.EqualFold(username, expected) { + t.Fatalf("Expected username %s: actual: %s", expected, username) + } + + t.Logf("Logged in username %s matches expected %s", username, expected) +} + +func TestFedAuthWithUserAndPassword(t *testing.T) { + mssql.SetLogger(testLogger{t}) + u := checkFedAuthUserPassword(t) + + checkLoggedInUser(u.User.Username(), u, t) +} + +func TestFedAuthWithApplicationUsingPassword(t *testing.T) { + mssql.SetLogger(testLogger{t}) + u := checkFedAuthAppPassword(t) + + checkLoggedInUser(u.User.Username(), u, t) +} + +func TestFedAuthWithApplicationUsingCertificate(t *testing.T) { + mssql.SetLogger(testLogger{t}) + u := checkFedAuthAppCertPath(t) + + checkLoggedInUser(u.User.Username(), u, t) +} + +func TestFedAuthWithSystemAssignedIdentity(t *testing.T) { + u, vmName := checkFedAuthVMSystemID(t) + mssql.SetLogger(testLogger{t}) + + checkLoggedInUser(vmName, u, t) +} + +func TestFedAuthWithUserAssignedIdentity(t *testing.T) { + mssql.SetLogger(testLogger{t}) + u, uaName := checkFedAuthVMUserAssignedID(t) + + checkLoggedInUser(uaName, u, t) +} diff --git a/azuread/configuration.go b/azuread/configuration.go new file mode 100644 index 00000000..ab40c023 --- /dev/null +++ b/azuread/configuration.go @@ -0,0 +1,212 @@ +package azuread + +import ( + "context" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "io/ioutil" + "strings" +) + +const ( + fedAuthActiveDirectoryPassword = "ActiveDirectoryPassword" + fedAuthActiveDirectoryIntegrated = "ActiveDirectoryIntegrated" + fedAuthActiveDirectoryMSI = "ActiveDirectoryMSI" + fedAuthActiveDirectoryApplication = "ActiveDirectoryApplication" +) + +// Federated authentication library affects the login data structure and message sequence. +const ( + // fedAuthLibraryLiveIDCompactToken specifies the Microsoft Live ID Compact Token authentication scheme + fedAuthLibraryLiveIDCompactToken = 0x00 + + // fedAuthLibrarySecurityToken specifies a token-based authentication where the token is available + // without additional information provided during the login sequence. + fedAuthLibrarySecurityToken = 0x01 + + // fedAuthLibraryADAL specifies a token-based authentication where a token is obtained during the + // login sequence using the server SPN and STS URL provided by the server during login. + fedAuthLibraryADAL = 0x02 + + // fedAuthLibraryReserved is used to indicate that no federated authentication scheme applies. + fedAuthLibraryReserved = 0x7F +) + +// Federated authentication ADAL workflow affects the mechanism used to authenticate. +const ( + // fedAuthADALWorkflowPassword uses a username/password to obtain a token from Active Directory + fedAuthADALWorkflowPassword = 0x01 + + // fedAuthADALWorkflowPassword uses the Windows identity to obtain a token from Active Directory + fedAuthADALWorkflowIntegrated = 0x02 + + // fedAuthADALWorkflowMSI uses the managed identity service to obtain a token + fedAuthADALWorkflowMSI = 0x03 +) + +type azureFedAuthConfig struct { + // The detected federated authentication library + fedAuthLibrary int + + // Service principal logins + clientID string + tenantID string + clientSecret string + certificate *x509.Certificate + privateKey *rsa.PrivateKey + + // ADAL workflows + adalWorkflow byte + user string + password string +} + +func validateParameters(params map[string]string) (p *azureFedAuthConfig, err error) { + p = &azureFedAuthConfig{ + fedAuthLibrary: fedAuthLibraryReserved, + } + + fedAuthWorkflow, _ := params["fedauth"] + if fedAuthWorkflow == "" { + return p, nil + } + + switch { + case strings.EqualFold(fedAuthWorkflow, fedAuthActiveDirectoryPassword): + p.fedAuthLibrary = fedAuthLibraryADAL + p.adalWorkflow = fedAuthADALWorkflowPassword + p.user, _ = params["user id"] + p.password, _ = params["password"] + + case strings.EqualFold(fedAuthWorkflow, fedAuthActiveDirectoryIntegrated): + // Active Directory Integrated authentication is not fully supported: + // you can only use this by also implementing an a token provider + // and supplying it via ActiveDirectoryTokenProvider in the Connection. + p.fedAuthLibrary = fedAuthLibraryADAL + p.adalWorkflow = fedAuthADALWorkflowIntegrated + + case strings.EqualFold(fedAuthWorkflow, fedAuthActiveDirectoryMSI): + // When using MSI, to request a specific client ID or user-assigned identity, + // provide the ID in the "ad client id" parameter + p.fedAuthLibrary = fedAuthLibraryADAL + p.adalWorkflow = fedAuthADALWorkflowMSI + p.clientID, _ = splitTenantAndClientID(params["user id"]) + + case strings.EqualFold(fedAuthWorkflow, fedAuthActiveDirectoryApplication): + p.fedAuthLibrary = fedAuthLibrarySecurityToken + + // Split the clientID@tenantID format + p.clientID, p.tenantID = splitTenantAndClientID(params["user id"]) + if p.clientID == "" || p.tenantID == "" { + return nil, errors.New("Must provide 'client id@tenant id' as username parameter when using ActiveDirectoryApplication authentication") + } + + p.clientSecret, _ = params["password"] + + pemPath, _ := params["clientcertpath"] + + if pemPath == "" && p.clientSecret == "" { + return nil, errors.New("Must provide 'password' parameter when using ActiveDirectoryApplication authentication without cert/key credentials") + } + + if pemPath != "" { + if p.certificate, p.privateKey, err = getFedAuthClientCertificate(pemPath, p.clientSecret); err != nil { + return nil, err + } + + p.clientSecret = "" + } + + default: + return nil, fmt.Errorf("Invalid federated authentication type '%s': expected %s, %s, %s or %s", + fedAuthWorkflow, fedAuthActiveDirectoryPassword, fedAuthActiveDirectoryMSI, + fedAuthActiveDirectoryApplication, fedAuthActiveDirectoryIntegrated) + } + + return p, nil +} + +func splitTenantAndClientID(user string) (string, string) { + // Split the user name into client id and tenant id at the @ symbol + at := strings.IndexRune(user, '@') + if at < 1 || at >= (len(user)-1) { + return user, "" + } + + return user[0:at], user[at+1:] +} + +func (p *azureFedAuthConfig) provideSecurityToken(ctx context.Context) (string, error) { + switch { + case p.certificate != nil && p.privateKey != nil: + return SecurityTokenFromCertificate(ctx, p.clientID, p.tenantID, p.certificate, p.privateKey) + case p.clientSecret != "": + return SecurityTokenFromSecret(ctx, p.clientID, p.tenantID, p.clientSecret) + } + + return "", errors.New("Client certificate and key, or client secret, required for service principal login") +} + +func (p *azureFedAuthConfig) provideActiveDirectoryToken(ctx context.Context, serverSPN, stsURL string) (string, error) { + switch p.adalWorkflow { + case fedAuthADALWorkflowPassword: + return ActiveDirectoryTokenFromPassword(ctx, serverSPN, stsURL, p.user, p.password) + case fedAuthADALWorkflowMSI: + return ActiveDirectoryTokenFromIdentity(ctx, serverSPN, stsURL, p.clientID) + } + + return "", fmt.Errorf("ADAL workflow id %d not supported", p.adalWorkflow) +} + +func getFedAuthClientCertificate(clientCertPath, clientCertPassword string) (certificate *x509.Certificate, privateKey *rsa.PrivateKey, err error) { + pemBytes, err := ioutil.ReadFile(clientCertPath) + if err != nil { + } + + var block, encryptedPrivateKey *pem.Block + var certificateBytes, privateKeyBytes []byte + + for block, pemBytes = pem.Decode(pemBytes); block != nil; block, pemBytes = pem.Decode(pemBytes) { + _, dekInfo := block.Headers["DEK-Info"] + switch { + case block.Type == "CERTIFICATE": + certificateBytes = block.Bytes + case block.Type == "RSA PRIVATE KEY" && dekInfo: + encryptedPrivateKey = block + case block.Type == "RSA PRIVATE KEY": + privateKeyBytes = block.Bytes + default: + return nil, nil, fmt.Errorf("PEM file %s contains unsupported block type %s", clientCertPath, block.Type) + } + } + + if len(certificateBytes) == 0 { + return nil, nil, fmt.Errorf("No certificate found in PEM file at path %s: %v", clientCertPath, err) + } + + certificate, err = x509.ParseCertificate(certificateBytes) + if err != nil { + return nil, nil, fmt.Errorf("Failed to parse certificate found in PEM file at path %s: %v", clientCertPath, err) + } + + if encryptedPrivateKey != nil { + privateKeyBytes, err = x509.DecryptPEMBlock(encryptedPrivateKey, []byte(clientCertPassword)) + if err != nil { + return nil, nil, fmt.Errorf("Failed to decrypt private key found in PEM file at path %s: %v", clientCertPath, err) + } + } + + if len(privateKeyBytes) == 0 { + return nil, nil, fmt.Errorf("No private key found in PEM file at path %s: %v", clientCertPath, err) + } + + privateKey, err = x509.ParsePKCS1PrivateKey(privateKeyBytes) + if err != nil { + return nil, nil, fmt.Errorf("Failed to parse private key found in PEM file at path %s: %v", clientCertPath, err) + } + + return +} diff --git a/azuread/configuration_test.go b/azuread/configuration_test.go new file mode 100644 index 00000000..5a56535d --- /dev/null +++ b/azuread/configuration_test.go @@ -0,0 +1,317 @@ +package azuread + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "io/ioutil" + "math/big" + "os" + "testing" + "time" +) + +func TestValidateParameters(t *testing.T) { + passphrase := "SuperSecret7" + certBlock, _, encryptedKeyBlock, err := generateTestCertAndKey(passphrase) + if err != nil { + t.Logf("Unable to generate certificate and keys: %v", err) + t.FailNow() + } + + pemFile, err := writePEMBlocksToFile([]*pem.Block{certBlock, encryptedKeyBlock}) + if err != nil { + t.Logf("Unable to write certificate and encrypted key to temporary file: %v", err) + t.FailNow() + } + + defer func() { + os.Remove(pemFile) + }() + + tests := []struct { + name string + params map[string]string + expected *azureFedAuthConfig + }{ + { + name: "no fed auth configured", + params: map[string]string{}, + expected: &azureFedAuthConfig{fedAuthLibrary: fedAuthLibraryReserved}, + }, + { + name: "application with cert/key", + params: map[string]string{ + "fedauth": "ActiveDirectoryApplication", + "user id": "service-principal-id@tenant-id", + "password": passphrase, + "clientcertpath": pemFile, + }, + expected: &azureFedAuthConfig{ + fedAuthLibrary: fedAuthLibrarySecurityToken, + clientID: "service-principal-id", + tenantID: "tenant-id", + certificate: &x509.Certificate{}, + privateKey: &rsa.PrivateKey{}, + }, + }, + { + name: "application with cert/key missing passphrase", + params: map[string]string{ + "fedauth": "ActiveDirectoryApplication", + "user id": "service-principal-id@tenant-id", + "clientcertpath": pemFile, + }, + expected: nil, + }, + { + name: "application with cert/key missing tenant id", + params: map[string]string{ + "fedauth": "ActiveDirectoryApplication", + "user id": "service-principal-id", + "password": passphrase, + "clientcertpath": pemFile, + }, + expected: nil, + }, + { + name: "application with secret", + params: map[string]string{ + "fedauth": "ActiveDirectoryApplication", + "user id": "service-principal-id@tenant-id", + "password": passphrase, + }, + expected: &azureFedAuthConfig{ + fedAuthLibrary: fedAuthLibrarySecurityToken, + clientID: "service-principal-id", + tenantID: "tenant-id", + clientSecret: passphrase, + }, + }, + { + name: "user with password", + params: map[string]string{ + "fedauth": "ActiveDirectoryPassword", + "user id": "azure-ad-user@example.com", + "password": "azure-ad-password", + }, + expected: &azureFedAuthConfig{ + fedAuthLibrary: fedAuthLibraryADAL, + adalWorkflow: fedAuthADALWorkflowPassword, + user: "azure-ad-user@example.com", + password: "azure-ad-password", + }, + }, + { + name: "managed identity without client id", + params: map[string]string{ + "fedauth": "ActiveDirectoryMSI", + }, + expected: &azureFedAuthConfig{ + fedAuthLibrary: fedAuthLibraryADAL, + adalWorkflow: fedAuthADALWorkflowMSI, + }, + }, + { + name: "managed identity with client id", + params: map[string]string{ + "fedauth": "ActiveDirectoryMSI", + "user id": "identity-client-id", + }, + expected: &azureFedAuthConfig{ + fedAuthLibrary: fedAuthLibraryADAL, + adalWorkflow: fedAuthADALWorkflowMSI, + clientID: "identity-client-id", + }, + }, + } + + for _, tst := range tests { + config, err := validateParameters(tst.params) + if tst.expected == nil { + if err == nil { + t.Errorf("No error returned when error expected in test case '%s'", tst.name) + } + continue + } + + if err != nil { + t.Errorf("Error returned when none expected in test case '%s': %v", tst.name, err) + continue + } + + if tst.expected.certificate != nil && config.certificate != nil { + config.certificate = tst.expected.certificate + } + + if tst.expected.privateKey != nil && config.privateKey != nil { + config.privateKey = tst.expected.privateKey + } + + if *config != *tst.expected { + t.Errorf("Captured parameters do not match in test case '%s'", tst.name) + } + } +} + +func TestGetFedAuthClientCertificate(t *testing.T) { + passphrase := "SuperSecret7" + certBlock, keyBlock, encryptedKeyBlock, err := generateTestCertAndKey(passphrase) + if err != nil { + t.Logf("Unable to generate certificate and keys: %v", err) + t.FailNow() + } + + expectValid := func(name string) func(*x509.Certificate, *rsa.PrivateKey, error) { + return func(cert *x509.Certificate, key *rsa.PrivateKey, err error) { + if err != nil { + t.Errorf("Error loading %s test case certificate and key: %v", name, err) + } else { + if cert == nil { + t.Errorf("Expected cert but found nil in %s test case", name) + } + + if key == nil { + t.Errorf("Expected key but found nil in %s test case", name) + } + } + } + } + + expectError := func(name string) func(*x509.Certificate, *rsa.PrivateKey, error) { + return func(cert *x509.Certificate, key *rsa.PrivateKey, err error) { + if err == nil { + t.Errorf("Did not get expected error while loading %s test case certificate and key", name) + } + } + } + + tests := []struct { + name string + blocks []*pem.Block + loadPassphrase string + verifier func(certificate *x509.Certificate, privateKey *rsa.PrivateKey, err error) + }{ + { + name: "valid unencrypted", + blocks: []*pem.Block{certBlock, keyBlock}, + loadPassphrase: "", + verifier: expectValid("unencrypted"), + }, + { + name: "valid encrypted", + blocks: []*pem.Block{certBlock, encryptedKeyBlock}, + loadPassphrase: passphrase, + verifier: expectValid("encrypted"), + }, + { + name: "empty", + blocks: []*pem.Block{}, + loadPassphrase: "", + verifier: expectError("empty"), + }, + { + name: "bogus block type", + blocks: []*pem.Block{&pem.Block{Type: "HOT GARBAGE", Bytes: []byte("HOTGARBAGE==")}}, + loadPassphrase: "", + verifier: expectError("bogus block type"), + }, + { + name: "bogus certificate", + blocks: []*pem.Block{&pem.Block{Type: "CERTIFICATE", Bytes: []byte("HOTGARBAGE==")}}, + loadPassphrase: "", + verifier: expectError("bogus certificate"), + }, + { + name: "no private key", + blocks: []*pem.Block{certBlock}, + loadPassphrase: "", + verifier: expectError("no private key"), + }, + { + name: "bogus private key", + blocks: []*pem.Block{certBlock, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: []byte("HOTGARBAGE==")}}, + loadPassphrase: "", + verifier: expectError("bogus private key"), + }, + { + name: "bogus encrypted private key", + blocks: []*pem.Block{certBlock, &pem.Block{Type: "RSA PRIVATE KEY", Headers: map[string]string{"DEK-Info": "AlsoGarbage"}, Bytes: []byte("HOTGARBAGE==")}}, + loadPassphrase: "", + verifier: expectError("bogus encrypted private key"), + }, + } + + for _, tst := range tests { + pemFile, err := writePEMBlocksToFile(tst.blocks) + if err != nil { + t.Logf("Unable to write PEM blocks for test case %s: %v", tst.name, err) + t.FailNow() + } + + func() { + defer func() { os.Remove(pemFile) }() + + cert, key, err := getFedAuthClientCertificate(pemFile, tst.loadPassphrase) + + tst.verifier(cert, key, err) + }() + } +} + +func generateTestCertAndKey(passphrase string) (*pem.Block, *pem.Block, *pem.Block, error) { + priv, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return nil, nil, nil, err + } + + keyBlock := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)} + + encryptedKeyBlock, err := x509.EncryptPEMBlock(rand.Reader, keyBlock.Type, keyBlock.Bytes, []byte(passphrase), x509.PEMCipherAES256) + if err != nil { + return nil, nil, nil, err + } + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"go-mssqldb"}, + }, + NotBefore: time.Now().Add(-(time.Minute * 5)), + NotAfter: time.Now().Add(time.Hour * 24), + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return nil, nil, nil, err + } + + certBlock := &pem.Block{Type: "CERTIFICATE", Bytes: derBytes} + + return certBlock, keyBlock, encryptedKeyBlock, nil +} + +func writePEMBlocksToFile(blocks []*pem.Block) (string, error) { + f, err := ioutil.TempFile("", "go-mssql-azureauth-") + if err != nil { + return "", err + } + + for _, block := range blocks { + if err = pem.Encode(f, block); err != nil { + return "", err + } + } + + if err = f.Close(); err != nil { + return "", err + } + + return f.Name(), nil +} diff --git a/azuread/conn_str.go b/azuread/conn_str.go new file mode 100644 index 00000000..59268a03 --- /dev/null +++ b/azuread/conn_str.go @@ -0,0 +1,55 @@ +package azuread + +import ( + "fmt" + "net" + "net/url" + "strings" +) + +// Splits a URL of the form sqlserver://username:password@host/instance?param1=value¶m2=value +func splitConnectionStringURL(dsn string) (map[string]string, error) { + res := map[string]string{} + + u, err := url.Parse(dsn) + if err != nil { + return res, err + } + + if u.Scheme != "sqlserver" { + return res, fmt.Errorf("scheme %s is not recognized", u.Scheme) + } + + if u.User != nil { + res["user id"] = u.User.Username() + p, exists := u.User.Password() + if exists { + res["password"] = p + } + } + + host, port, err := net.SplitHostPort(u.Host) + if err != nil { + host = u.Host + } + + if len(u.Path) > 0 { + res["server"] = host + "\\" + u.Path[1:] + } else { + res["server"] = host + } + + if len(port) > 0 { + res["port"] = port + } + + query := u.Query() + for k, v := range query { + if len(v) > 1 { + return res, fmt.Errorf("key %s provided more than once", k) + } + res[strings.ToLower(k)] = v[0] + } + + return res, nil +} diff --git a/azuread/driver.go b/azuread/driver.go new file mode 100644 index 00000000..3794317f --- /dev/null +++ b/azuread/driver.go @@ -0,0 +1,65 @@ +package azuread + +import ( + "context" + "database/sql" + "database/sql/driver" + + mssql "github.com/denisenkom/go-mssqldb" +) + +// DriverName is the name used to register the driver +const DriverName = "azuresql" + +func init() { + sql.Register(DriverName, &Driver{}) +} + +// Driver wraps the underlying MSSQL driver, but configures the Azure AD token provider +type Driver struct { +} + +// Open returns a new connection to the database. +func (d *Driver) Open(dsn string) (driver.Conn, error) { + c, err := NewConnector(dsn) + if err != nil { + return nil, err + } + + return c.Connect(context.Background()) +} + +// NewConnector creates a new connector from a DSN. +// The returned connector may be used with sql.OpenDB. +func NewConnector(dsn string) (*mssql.Connector, error) { + params, err := splitConnectionStringURL(dsn) + if err != nil { + return nil, err + } + + config, err := validateParameters(params) + if err != nil { + return nil, err + } + + switch config.fedAuthLibrary { + case fedAuthLibrarySecurityToken: + return mssql.NewSecurityTokenConnector( + dsn, + func(ctx context.Context) (string, error) { + return config.provideSecurityToken(ctx) + }, + ) + + case fedAuthLibraryADAL: + return mssql.NewActiveDirectoryTokenConnector( + dsn, config.adalWorkflow, + func(ctx context.Context, serverSPN, stsURL string) (string, error) { + return config.provideActiveDirectoryToken(ctx, serverSPN, stsURL) + }, + ) + + default: + return mssql.NewConnector(dsn) + } +} diff --git a/go.mod b/go.mod index ebc02ab8..67fcf4dd 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/denisenkom/go-mssqldb go 1.11 require ( + github.com/Azure/go-autorest/autorest/adal v0.9.9 github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe - golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c + golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0 ) diff --git a/go.sum b/go.sum index 1887801b..6e0ef725 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,21 @@ +github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs= +github.com/Azure/go-autorest v14.2.0+incompatible/go.mod h1:r+4oMnoxhatjLLJ6zxSWATqVooLgysK6ZNox3g/xq24= +github.com/Azure/go-autorest/autorest/adal v0.9.9 h1:y/DT2jMCd/Bme1PJzdp5OtiE16LznXG4YSlcNBqW4Us= +github.com/Azure/go-autorest/autorest/adal v0.9.9/go.mod h1:B7KF7jKIeC9Mct5spmyCB/A8CG/sEz1vwIRGv/bbw7A= +github.com/Azure/go-autorest/autorest/date v0.3.0 h1:7gUk1U5M/CQbp9WoqinNzJar+8KY+LPI6wiWrP/myHw= +github.com/Azure/go-autorest/autorest/date v0.3.0/go.mod h1:BI0uouVdmngYNUzGWeSYnokU+TrmwEsOqdt8Y6sso74= +github.com/Azure/go-autorest/autorest/mocks v0.4.1 h1:K0laFcLE6VLTOwNgSxaGbUcLPuGXlNkbVvq4cW4nIHk= +github.com/Azure/go-autorest/autorest/mocks v0.4.1/go.mod h1:LTp+uSrOhSkaKrUy935gNZuuIPPVsHlr9DSOxSayd+k= +github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo= +github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU= +github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk= +github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= -golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI= -golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0 h1:hb9wdF1z5waM+dSIICn1l0DkLVDT3hqhhQsDNUmHPRE= +golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= From acdd0dd200e49927803bf66734364ed51dda7093 Mon Sep 17 00:00:00 2001 From: "Rose, William" Date: Mon, 14 Dec 2020 00:25:31 -0800 Subject: [PATCH 4/6] Skip test for connection close that does not work when TLS is applied --- queries_test.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/queries_test.go b/queries_test.go index f89f9013..c6b3ee08 100644 --- a/queries_test.go +++ b/queries_test.go @@ -1300,6 +1300,11 @@ func TestProcessQueryErrors(t *testing.T) { } func TestProcessQueryNextErrors(t *testing.T) { + params := testConnParams(t) + if params.encrypt { + t.Skip("Unable to test connection close as TLS wrapping hides underlying socket") + } + conn := internalConnection(t) statements := make([]string, 1000) for i := 0; i < len(statements); i++ { From 2b05e757f94e3d69bfb17f38f3d4dc7170d666b4 Mon Sep 17 00:00:00 2001 From: "Rose, William" Date: Mon, 14 Dec 2020 10:50:06 -0800 Subject: [PATCH 5/6] Skip Azure AD for Go 1.8 --- appveyor.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/appveyor.yml b/appveyor.yml index dfcb62de..95766126 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -12,8 +12,10 @@ environment: SQLPASSWORD: Password12! DATABASE: test GOVERSION: 113 + ADALSUPPORT: test matrix: - GOVERSION: 18 + ADALSUPPORT: no-test SQLINSTANCE: SQL2017 - GOVERSION: 19 SQLINSTANCE: SQL2017 @@ -46,6 +48,7 @@ install: - go version - go env - go get -u github.com/golang-sql/civil + - if %ADALSUPPORT%==test go get -u github.com/Azure/go-autorest/autorest/adal build_script: - go build From df5b6c52065100ece8f1a8fd41937d3afded83e3 Mon Sep 17 00:00:00 2001 From: "Rose, William" Date: Sun, 13 Dec 2020 18:46:59 -0800 Subject: [PATCH 6/6] Update documentation for Azure AD authentication and add examples --- README.md | 96 +++- doc/how-to-test-azure-ad-authentication.md | 178 ++++++++ examples/azuread/.gitignore | 2 + examples/azuread/azuread.go | 144 ++++++ examples/azuread/dsn-variables.jq | 10 + examples/azuread/environment-settings.jq | 20 + examples/azuread/go.mod | 10 + examples/azuread/testing.tf | 506 +++++++++++++++++++++ examples/simple/simple.go | 43 +- examples/tvp/tvp.go | 2 + 10 files changed, 982 insertions(+), 29 deletions(-) create mode 100644 doc/how-to-test-azure-ad-authentication.md create mode 100644 examples/azuread/.gitignore create mode 100644 examples/azuread/azuread.go create mode 100644 examples/azuread/dsn-variables.jq create mode 100644 examples/azuread/environment-settings.jq create mode 100644 examples/azuread/go.mod create mode 100644 examples/azuread/testing.tf diff --git a/README.md b/README.md index 94d87fe0..36fc694a 100644 --- a/README.md +++ b/README.md @@ -54,10 +54,11 @@ Other supported formats are listed below. * true - Server certificate is not checked. Default is true if encrypt is not specified. If trust server certificate is true, driver accepts any certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing. * `certificate` - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates. * `hostNameInCertificate` - Specifies the Common Name (CN) in the server certificate. Default value is the server host. -* `ServerSPN` - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port. +* `ServerSPN` - The Kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port. * `Workstation ID` - The workstation name (default is the host name) * `ApplicationIntent` - Can be given the value `ReadOnly` to initiate a read-only connection to an Availability Group listener. The `database` must be specified when connecting with `Application Intent` set to `ReadOnly`. + ### The connection string can be specified in one of three formats: @@ -106,25 +107,88 @@ Other supported formats are listed below. * `odbc:server=localhost;user id=sa;password={foo{bar}` // Literal `{`, password is "foo{bar" * `odbc:server=localhost;user id=sa;password={foo}}bar}` // Escaped `} with `}}`, password is "foo}bar" -### Azure Active Directory authentication - preview +### Azure Active Directory authentication -The configuration of functionality might change in the future. +Azure Active Directory authentication uses temporary authentication tokens to authenticate. +The `mssql` package does not provide an implementation to obtain tokens: instead, import the +`azuread` package and use driver name `azuresql`. This driver uses the +[Active Directory Authentication Library for Go](https://github.com/Azure/go-autorest/tree/master/autorest/adal) +to obtain Azure Active Directory authentication tokens. -Azure Active Directory (AAD) access tokens are relatively short lived and need to be -valid when a new connection is made. Authentication is supported using a callback func that -provides a fresh and valid token using a connector: -``` golang -conn, err := mssql.NewAccessTokenConnector( - "Server=test.database.windows.net;Database=testdb", - tokenProvider) -if err != nil { - // handle errors in DSN +Authentication using Active Directory is enabled using the `fedauth` connection parameter, +in combination with the `user id` and `password` (or URL username and password). + + * `fedauth=ActiveDirectoryApplication` - authenticates using an Azure Active Directory application client ID and client secret or certificate. + + Set the `user id` to `clientID@tenantID` for your service principal. If using a client secret, set the `password` to the client secret. If using client certificates, provide the path to the PEM file containing the certificate concatenated with the RSA private key in the `clientcertpath` parameter, and set the `password` to the passphrase needed to decrypt the RSA private key (omit or leave blank if unencrypted). + + * `fedauth=ActiveDirectoryMSI` - authenticates using the managed service identity (MSI) attached to the VM (system identity), or a specific user-assigned identity. + + To select a user-assigned identity, specify a client ID in the `user id` parameter. For the system-assigned identity, leave the `user id` empty. + + * `fedauth=ActiveDirectoryPassword` - authenticates an Azure Active Directory user account. + + Set the `user id` to `user@domain.com` and the `password`. This method is not recommended for general use and does not support multi-factor authentication for accounts. + + +```golang +import ( + "database/sql" + "net/url" + + // Import the Azure AD driver module (also imports the regular driver package) + "github.com/denisenkom/go-mssqldb/azuread" +) + +func ConnectWithMSI() (*sql.DB, error) { + return sql.Open(azuread.DriverName, "sqlserver://azuresql.database.windows.net?database=yourdb&fedauth=ActiveDirectoryMSI") +} +``` + +As an alternative, you can select the federated authentication library and Active Directory +using the connection string parameters, but then implement your own routine for obtaining +tokens. The second example shows how this could be used to add in a token for the Azure AD +Integrated authentication scenario. + +```golang +import ( + "context" + "database/sql" + "net/url" + + // Import the driver + "github.com/denisenkom/go-mssqldb" +) + +func ConnectWithSecurityToken() (*sql.DB, error) { + conn, err := mssql.NewSecurityTokenConnector( + "sqlserver://azuresql.database.windows.net?database=yourdb", + func(ctx context.Context) (string, error) { + return "the token", nil + }, + ) + if err != nil { + // handle errors in DSN + } + + return sql.OpenDB(conn), nil +} + +func ConnectWithADIntegrated() (*sql.DB, error) { + conn, err := mssql.NewActiveDirectoryTokenConnector( + "sqlserver://azuresq;.database.windows.net?database=yourdb", + 2, // Active Directory workflow: 1 = user/password, 2 = integrated, 3 = MSI + func(ctx context.Context, serverSPN, stsURL string) (string, error) { + return "the token", nil + }, + ) + if err != nil { + // handle errors in DSN + } + + return sql.OpenDB(conn), nil } -db := sql.OpenDB(conn) ``` -Where `tokenProvider` is a function that returns a fresh access token or an error. None of these statements -actually trigger the retrieval of a token, this happens when the first statment is issued and a connection -is created. ## Executing Stored Procedures diff --git a/doc/how-to-test-azure-ad-authentication.md b/doc/how-to-test-azure-ad-authentication.md new file mode 100644 index 00000000..12574506 --- /dev/null +++ b/doc/how-to-test-azure-ad-authentication.md @@ -0,0 +1,178 @@ +# How to test Azure AD authentication + +To test Azure AD authentication requires an Azure SQL server configured with an +[Active Directory administrator](https://docs.microsoft.com/en-us/azure/sql-database/sql-database-aad-authentication-configure). +To test managed identity authentication, an Azure virtual machine configured with +[system-assigned and/or user-assigned identities](https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/qs-configure-portal-windows-vm) +is also required. + +The necessary resources can be set up through any means including the +[Azure Portal](https://portal.azure.com/), the Azure CLI, the Azure PowerShell cmdlets or +[Terraform](https://terraform.io/). To support these instructions, use the Terraform script at +[examples/azuread/testing.tf](../examples/azuread/testing.tf). + +## Create Azure infrastructure + +Download [Terraform](https://terraform.io/) to a location on your PATH. + +Log in to Azure using the Azure CLI. + +```console +you@workstation:~$ az login +you@workstation:~$ az account show +``` + +If your Azure account has access to multiple subscriptions, use +`az account set --subscription ` to choose the correct one. You will need to have at +least Contributor access to the portal and permissions in Azure Active Directory to create users +and grants. + +Check out this source repository (if you haven't already), change to the `examples/azuread` +directory and run Terraform: + +```console +you@workstation:~$ git clone -b azure-auth https://github.com/wrosenuance/go-mssqldb.git +you@workstation:~$ cd go-mssqldb/examples/azuread +you@workstation:azuread$ terraform init +you@workstation:azuread$ terraform apply +``` + +This will create an Azure resource group, a SQL server with a database, a virtual machine with a +system-assigned identity and user-assigned identity. Resources are named based on a random +prefix: to specify the prefix, use `terraform apply -var prefix=`. + +Upon successful completion, Terraform will display some key details of the infrastructure that has + been created. This includes the SSH key to access the VM, the administrator account and password + for the Azure SQL server, and all the relevant resource names. + +Save the settings to a JSON file: + +```console +you@workstation:azuread$ terraform output -json > settings.json +``` + +Save the SSH private key to a file: + +```console +you@workstation:azuread$ terraform output vm_user_ssh_private_key > ssh-identity +``` + +Copy the `settings.json` to the new VM: + +```console +you@workstation:azuread$ eval "VM_ADMIN_NAME=$(terraform output vm_admin_name)" +you@workstation:azuread$ eval "VM_IP_ADDRESS=$(terraform output vm_ip_address)" +you@workstation:azuread$ scp -i ssh-identity settings.json "${VM_ADMIN_NAME}@${VM_IP_ADDRESS}:" +``` + +## Set up Azure Virtual Machine for testing + +SSH to the new VM to continue setup: + +```console +you@workstation:azuread$ ssh -i ssh-identity "${VM_ADMIN_NAME}@${VM_IP_ADDRESS}" +``` + +Once on the VM, update the system and install some basic packages: + +```console +azureuser@azure-vm:~$ sudo apt update -y +azureuser@azure-vm:~$ sudo apt upgrade -y +azureuser@azure-vm:~$ sudo apt install -y git openssl jq build-essential +azureuser@azure-vm:~$ sudo snap install go --classic +``` + +Install the Azure CLI using the script as shown below, or follow the +[manual install instructions](https://docs.microsoft.com/en-us/cli/azure/install-azure-cli-apt): + +```console +azureuser@azure-vm:~$ curl -sL https://aka.ms/InstallAzureCLIDeb | sudo bash +``` + +## Generate service principal certificate file + +Log in to Azure on the VM and set the subscription: + +```console +azureuser@azure-vm:~$ az login +azureuser@azure-vm:~$ az account set --subscription "$(jq -r '.subscription_id.value' settings.json)" +``` + +Use OpenSSL to create a new certificate and key in PEM format, using the : + +```console +azureuser@azure-vm:~$ openssl rand -writerand ~/.rnd +azureuser@azure-vm:~$ openssl req -x509 -nodes -newkey rsa:4096 -keyout client.key -out client.crt \ + -subj "/C=US/ST=MA/L=Boston/O=Global Security/OU=IT Department/CN=AD-SP" +azureuser@azure-vm:~$ openssl rsa -out client.pem -in client.key -aes256 \ + -passout "pass:$(jq -r '.app_sp_client_secret.value' settings.json)" +azureuser@azure-vm:~$ cat client.crt >> client.pem +azureuser@azure-vm:~$ export APP_SP_CLIENT_CERT="$PWD/client.pem" +``` + +Use the Azure CLI to add the client certificate to the application service principal: + +```console +azureuser@azure-vm:~$ az ad sp credential reset --append --cert @client.crt \ + --name "$(jq -r '.app_sp_client_id.value' settings.json)" +``` + +## Build source code and authorize users in database + +Clone this repository, build and run the `examples/azuread` helper that verifies the database +exists and sets up access for the system-assigned and user-assigned identities. + +```console +azureuser@azure-vm:~$ git clone -b azure-auth https://github.com/wrosenuance/go-mssqldb.git +azureuser@azure-vm:~$ cd go-mssqldb +azureuser@azure-vm:go-mssqldb$ (cd ./examples/azuread; go build -o ../../azuread-example .) +azureuser@azure-vm:go-mssqldb$ eval "$(jq -r -f examples/azuread/environment-settings.jq ../settings.json)" +azureuser@azure-vm:go-mssqldb$ ./azuread-example -fedauth ActiveDirectoryPassword +``` + +For some basic connectivity tests, use the `examples/simple` helper. Run these commands on the +Azure VM so that identity authentication is possible. + +```console +azureuser@azure-vm:go-mssqldb$ eval "$(jq -r --arg certpath "$(realpath ../client.pem)" -f examples/azuread/dsn-variables.jq ../settings.json)" +azureuser@azure-vm:go-mssqldb$ go build -o simple ./examples/simple +azureuser@azure-vm:go-mssqldb$ ./simple -debug -dsn "$AD_APP_CERT_DSN" +azureuser@azure-vm:go-mssqldb$ ./simple -debug -dsn "$AD_APP_PWD_DSN" +azureuser@azure-vm:go-mssqldb$ ./simple -debug -dsn "$AD_MSI_SYS_DSN" +azureuser@azure-vm:go-mssqldb$ ./simple -debug -dsn "$AD_MSI_USER_DSN" +azureuser@azure-vm:go-mssqldb$ ./simple -debug -dsn "$AD_USER_PWD_DSN" +azureuser@azure-vm:go-mssqldb$ ./simple -debug -dsn "$SQL_USER_PWD_DSN" +``` + +## Running the integration tests + +Now that your environment is configured, you can run `go test`: + +```console +azureuser@azure-vm:go-mssqldb$ export SQLSERVER_DSN="$AD_APP_CERT_DSN" +azureuser@azure-vm:go-mssqldb$ go test -coverprofile=coverage.out . ./azuread ./batch ./internal/... +azureuser@azure-vm:go-mssqldb$ go tool cover -html=coverage.out -o coverage.html +``` + +## Tear down environment + +After you complete your testing, use Terraform to destroy the infrastructure you created. + +```console +you@workstation:azuread$ terraform destroy +``` + +## Troubleshooting + +After Terraform runs you should be able to see resources that were created in the +[Azure Portal](https://portal.azure.com/). + +If the Azure SQL server is successfully created you can connect to it using the AD admin user +and password in SSMS. SSMS will prompt you to create firewall rules if they are missing. You +can read the AD admin user and password from the `settings.json`, or run: + +```console +you@workstation:azuread$ terraform output sql_ad_admin_user +you@workstation:azuread$ terraform output sql_ad_admin_password +``` + diff --git a/examples/azuread/.gitignore b/examples/azuread/.gitignore new file mode 100644 index 00000000..e2a8f423 --- /dev/null +++ b/examples/azuread/.gitignore @@ -0,0 +1,2 @@ +settings.json +ssh-identity diff --git a/examples/azuread/azuread.go b/examples/azuread/azuread.go new file mode 100644 index 00000000..3833e794 --- /dev/null +++ b/examples/azuread/azuread.go @@ -0,0 +1,144 @@ +package main + +import ( + "database/sql" + "flag" + "fmt" + "log" + "net/url" + "os" + "strings" + "time" + + "github.com/denisenkom/go-mssqldb/azuread" +) + +var ( + debug = flag.Bool("debug", false, "enable debugging") + server = flag.String("server", os.Getenv("SQL_SERVER"), "the database server name") + port = flag.Int("port", 1433, "the database port") + database = flag.String("database", os.Getenv("SQL_DATABASE"), "the database name") + user = flag.String("user", os.Getenv("SQL_AD_ADMIN_USER"), "the AD administrator user name") + password = flag.String("password", os.Getenv("SQL_AD_ADMIN_PASSWORD"), "the AD administrator password") + fedauth = flag.String("fedauth", "ActiveDirectoryPassword", "the federated authentication scheme to use") + appName = flag.String("app-name", os.Getenv("APP_NAME"), "the application name to authorize") + vmName = flag.String("vm-name", os.Getenv("VM_NAME"), "the system identity name to authorize for this VM") + uaName = flag.String("ua-name", os.Getenv("UA_NAME"), "the user assigned identity name to authorize for this VM") +) + +func createConnStr(database string) string { + connString := fmt.Sprintf("sqlserver://%s:%s@%s:%d?encrypt=true", + url.QueryEscape(*user), url.QueryEscape(*password), + url.QueryEscape(*server), *port) + + if database != "" && database != "master" { + connString = connString + "&database=" + url.QueryEscape(database) + } + + if *fedauth != "" { + connString = connString + "&fedauth=" + url.QueryEscape(*fedauth) + } + + if *debug { + connString = connString + "&log=127" + } + + return connString +} + +func createDatabaseIfNotExists() error { + // Check database exists by connecting to master on the Azure SQL server + connString := createConnStr("master") + + log.Printf("Open: %s\n", connString) + + conn, err := sql.Open(azuread.DriverName, connString) + if err != nil { + return err + } + + defer conn.Close() + + if err = conn.Ping(); err != nil { + return err + } + + quoted := strings.Replace(*database, "]", "]]", -1) + sql := "IF NOT EXISTS (SELECT 1 FROM sys.databases WHERE name = @p1)\n CREATE DATABASE [" + quoted + "] ( SERVICE_OBJECTIVE = 'S0' )" + log.Printf("Exec: @p1 = '%s'\n%s\n", *database, sql) + _, err = conn.Exec(sql, *database) + + return err +} + +func addExternalUserIfNotExists(user string) error { + connString := createConnStr(*database) + + log.Printf("Open: %s\n", connString) + + var conn *sql.DB + var err error + + for retry := 0; retry < 8; retry++ { + conn, err = sql.Open(azuread.DriverName, connString) + if err == nil { + if err = conn.Ping(); err == nil { + break + } + } + log.Printf("Connection failed: %v", err) + log.Println("Retry in 15 seconds") + time.Sleep(15 * time.Second) + } + if err != nil { + log.Printf("Connection failed: %v", err) + log.Println("No further retries will be attempted") + return err + } + + defer conn.Close() + + quoted := strings.Replace(user, "]", "]]", -1) + sql := "IF NOT EXISTS (SELECT 1 FROM sys.database_principals WHERE name = @p1)\n CREATE USER [" + quoted + "] FROM EXTERNAL PROVIDER" + log.Printf("Exec: @p1 = '%s'\n%s\n", user, sql) + _, err = conn.Exec(sql, user) + if err != nil { + return err + } + + sql = "IF IS_ROLEMEMBER('db_owner', @p1) = 0\n ALTER ROLE [db_owner] ADD MEMBER [" + quoted + "]" + log.Printf("Exec: @p1 = '%s'\n%s\n", user, sql) + _, err = conn.Exec(sql, user) + + return err +} + +func main() { + flag.Parse() + + err := createDatabaseIfNotExists() + if err != nil { + log.Fatalf("Unable to create database [%s]: %v", *database, err) + } + + if *vmName != "" { + err = addExternalUserIfNotExists(*vmName) + if err != nil { + log.Fatalf("Unable to create user for system-assigned identity [%s]: %v", *vmName, err) + } + } + + if *appName != "" { + err = addExternalUserIfNotExists(*appName) + if err != nil { + log.Fatalf("Unable to create user for application identity [%s]: %v", *appName, err) + } + } + + if *uaName != "" { + err = addExternalUserIfNotExists(*uaName) + if err != nil { + log.Fatalf("Unable to create user for user-assigned identity [%s]: %v", *uaName, err) + } + } +} diff --git a/examples/azuread/dsn-variables.jq b/examples/azuread/dsn-variables.jq new file mode 100644 index 00000000..810dcb09 --- /dev/null +++ b/examples/azuread/dsn-variables.jq @@ -0,0 +1,10 @@ +[ + "set -a", + "AD_APP_CERT_DSN=" + (@uri "sqlserver://\(.app_sp_client_id.value)%40\(.tenant_id.value):\(.app_sp_client_secret.value)@\(.sql_server_fqdn.value)?database=\(.sql_database_name.value)&encrypt=true&fedauth=ActiveDirectoryApplication&clientcertpath=\($certpath)" | @sh), + "AD_APP_PWD_DSN=" + (@uri "sqlserver://\(.app_sp_client_id.value)%40\(.tenant_id.value):\(.app_sp_client_secret.value)@\(.sql_server_fqdn.value)?database=\(.sql_database_name.value)&encrypt=true&fedauth=ActiveDirectoryApplication" | @sh), + "AD_MSI_SYS_DSN=" + (@uri "sqlserver://\(.sql_server_fqdn.value)?database=\(.sql_database_name.value)&encrypt=true&fedauth=ActiveDirectoryMSI" | @sh), + "AD_MSI_USER_DSN=" + (@uri "sqlserver://\(.user_assigned_identity_client_id.value)@\(.sql_server_fqdn.value)?database=\(.sql_database_name.value)&encrypt=true&fedauth=ActiveDirectoryMSI" | @sh), + "AD_USER_PWD_DSN=" + (@uri "sqlserver://\(.sql_ad_admin_user.value):\(.sql_ad_admin_password.value)@\(.sql_server_fqdn.value)?database=\(.sql_database_name.value)&encrypt=true&fedauth=ActiveDirectoryPassword" | @sh), + "SQL_USER_PWD_DSN=" + (@uri "sqlserver://\(.sql_admin_user.value):\(.sql_admin_password.value)@\(.sql_server_fqdn.value)?database=\(.sql_database_name.value)&encrypt=true" | @sh), + "set +a" +] | map([.]) | .[] | @tsv diff --git a/examples/azuread/environment-settings.jq b/examples/azuread/environment-settings.jq new file mode 100644 index 00000000..a8c9192a --- /dev/null +++ b/examples/azuread/environment-settings.jq @@ -0,0 +1,20 @@ +# Convert Terraform settings to shell environment exports. +[ + "set -a", + "SQL_SERVER=" + (.sql_server_fqdn.value | @sh), + "SQL_ADMIN_USER=" + (.sql_admin_user.value | @sh), + "SQL_ADMIN_PASSWORD=" + (.sql_admin_password.value | @sh), + "SQL_AD_ADMIN_USER=" + (.sql_ad_admin_user.value | @sh), + "SQL_AD_ADMIN_PASSWORD=" + (.sql_ad_admin_password.value | @sh), + "APP_SP_CLIENT_ID=" + (.app_sp_client_id.value | @sh), + "APP_SP_CLIENT_SECRET=" + (.app_sp_client_secret.value | @sh), + "SQL_DATABASE=" + (.sql_database_name.value | @sh), + "APP_NAME=" + (.app_name.value | @sh), + "VM_NAME=" + (.vm_name.value | @sh), + "VM_CLIENT_ID=" + (.vm_client_id.value | @sh), + "UA_NAME=" + (.user_assigned_identity_name.value | @sh), + "UA_CLIENT_ID=" + (.user_assigned_identity_client_id.value | @sh), + "AZURE_SUBSCRIPTION_ID=" + (.subscription_id.value | @sh), + "AZURE_TENANT_ID=" + (.tenant_id.value | @sh), + "set +a" +] | map([.]) | .[] | @tsv diff --git a/examples/azuread/go.mod b/examples/azuread/go.mod new file mode 100644 index 00000000..217d12b0 --- /dev/null +++ b/examples/azuread/go.mod @@ -0,0 +1,10 @@ +module github.com/denisenkom/go-mssqldb/examples/azuread + +go 1.13 + +require ( + github.com/Azure/go-autorest/autorest/adal v0.8.1 + github.com/denisenkom/go-mssqldb v0.0.0-20191128021309-1d7a30a10f73 +) + +replace github.com/denisenkom/go-mssqldb => ../.. \ No newline at end of file diff --git a/examples/azuread/testing.tf b/examples/azuread/testing.tf new file mode 100644 index 00000000..488df9a0 --- /dev/null +++ b/examples/azuread/testing.tf @@ -0,0 +1,506 @@ +# +# Terraform setup for Azure SQL with Azure Active Directory authentication +# + +# +# Set up Terraform provider versions +# + +terraform { + required_providers { + azuread = { + source = "hashicorp/azuread" + version = "=1.1.1" + } + + azurerm = { + source = "hashicorp/azurerm" + version = "=2.40.0" + } + + http = { + source = "hashicorp/http" + version = "=2.0.0" + } + + random = { + version = "=3.0.0" + } + + tls = { + version = "=3.0.0" + } + } +} + +provider "azurerm" { + features {} +} + +# +# Variables +# +# These variables allow limited overrides to control the resource creation. +# To specify, run terraform apply -var name1=value1 [-var name2=value2]... +# E.g. terraform apply -var prefix=my-stuff +# will use "my-stuff" in place of the randomly generated ID that is used by default. +# +variable "prefix" { + description = "Prefix for Azure resource names" + type = string + default = "" +} + +variable "location" { + description = "Azure location for resources" + type = string + default = "East US" +} + +variable "vm_admin_name" { + description = "Name of administrative user on virtual machine" + type = string + default = "azureuser" +} + +variable "ssh_key" { + description = "Path to RSA SSH private key (unencrypted)" + type = string + default = "~/.ssh/id_rsa" +} + +variable "workstation_ip" { + description = "IP address of this workstation to add to SQL server firewall rules" + type = string + default = "" +} + +# +# If the prefix is not specified via the variable, a sixteen character alphanumeric suffix is +# generated and then the prefix is set to "go-mssql-test-" + +# +resource "random_string" "random_prefix" { + length = 16 + lower = true + number = true + upper = false + special = false +} + +# +# Set up a local variable to capture the prefix to use - either the user-specified from the +# variable, or else the generated name using the random string above. +# +# Some resource names (e.g. SQL server) are more restricted than others - e.g. hyphens are +# not permitted - so we create a restricted name prefix as well as a regular name prefix. +# +locals { + regular_name_prefix = var.prefix != "" ? var.prefix : "go-mssql-test-${random_string.random_prefix.result}" + restricted_name_prefix = var.prefix != "" ? lower(replace(var.prefix, "/[^A-Za-z0-9]/", "")) : "gomssqltest${random_string.random_prefix.result}" +} + +# +# SSH Key - generate if not available at the file named in the variable. +# Terraform will complain if var.ssh_key is empty as this is interpreted as referring to the +# current working directory, and that is not a file. Instead, if you want to avoid using an +# existing SSH key, make it a literal "no" or some other string that is not an existing file or +# directory. +# +data "tls_public_key" "file_ssh_key" { + count = fileexists(var.ssh_key) ? 1 : 0 + private_key_pem = fileexists(var.ssh_key) ? file(var.ssh_key) : "" +} + +resource "tls_private_key" "rand_ssh_key" { + algorithm = "ECDSA" +} + +locals { + private_key_pem = fileexists(var.ssh_key) ? data.tls_public_key.file_ssh_key.0.private_key_pem : tls_private_key.rand_ssh_key.private_key_pem + public_key_pem = fileexists(var.ssh_key) ? data.tls_public_key.file_ssh_key.0.public_key_pem : tls_private_key.rand_ssh_key.public_key_pem + public_key_openssh = fileexists(var.ssh_key) ? data.tls_public_key.file_ssh_key.0.public_key_openssh : tls_private_key.rand_ssh_key.public_key_openssh +} + +# +# Retrieve tenant, subscription and default domain information based on the current Azure login. +# +data "azurerm_client_config" "current" { +} + +data "azurerm_subscription" "current" { +} + +data "azuread_domains" "current" { + only_default = "true" +} + +# +# Use ipify.org to determine workstation IP if not provided. +# If this guesses incorrectly, specify your workstation IP with -var worstation_ip= +# when you run terraform apply. +# +data "http" "workstation_ip" { + url = "https://api.ipify.org/" +} + +locals { + workstation_ip = var.workstation_ip != "" ? var.workstation_ip : chomp(data.http.workstation_ip.body) +} + +# +# Set up the Azure resource group for all the test resources. +# +resource "azurerm_resource_group" "rg" { + name = "${local.regular_name_prefix}-rg" + location = var.location +} + +# +# Set up an AD User to use as AD Administrator for the Azure SQL server. +# +# Using a regular user account makes it simpler to log in as the user with SSMS or the Go +# driver when setting up the other permissions for the identities that will be tested. +# It appears to although you can make the AD Administrator a service principal, doing so +# leads to issues during logins that do not occur when the AD Administrator is a normal +# AD User account. +# +resource "random_password" "sql_ad_admin_sp_password" { + length = 32 + special = true +} + +resource "azuread_user" "sql_ad_admin" { + user_principal_name = "SQLAdmin.${local.restricted_name_prefix}@${data.azuread_domains.current.domains[0].domain_name}" + display_name = "SQL Admin for ${local.restricted_name_prefix}" + mail_nickname = "SQLAdmin.${local.restricted_name_prefix}" + password = random_password.sql_ad_admin_sp_password.result +} + +# +# Set up the Azure SQL Server +# +# A normal (non-AD) administrator username and password are also provisioned. However, it is +# not possible to create AD users without logging in via an AD-authenticated account, so this +# non-AD administrator is not able to create new AD user accounts. +# +resource "random_password" "sql_admin_password" { + length = 16 + special = true +} + +resource "azurerm_sql_server" "sql_server" { + name = local.restricted_name_prefix + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + + version = "12.0" + administrator_login = "sql-admin" + administrator_login_password = random_password.sql_admin_password.result +} + +resource "azurerm_sql_active_directory_administrator" "sql_server" { + server_name = azurerm_sql_server.sql_server.name + resource_group_name = azurerm_sql_server.sql_server.resource_group_name + login = "sql-ad-admin" + tenant_id = data.azurerm_client_config.current.tenant_id + object_id = azuread_user.sql_ad_admin.id +} + +resource "azurerm_sql_firewall_rule" "sql_server_allow_azure" { + name = "AllowAzureAccess" + server_name = azurerm_sql_server.sql_server.name + resource_group_name = azurerm_sql_server.sql_server.resource_group_name + start_ip_address = "0.0.0.0" + end_ip_address = "0.0.0.0" +} + +resource "azurerm_sql_firewall_rule" "sql_server_allow_workstation" { + name = "AllowWorkstationAccess" + server_name = azurerm_sql_server.sql_server.name + resource_group_name = azurerm_sql_server.sql_server.resource_group_name + start_ip_address = local.workstation_ip + end_ip_address = local.workstation_ip +} + +# +# Set up the test database on the Azure SQL server +# +resource "azurerm_sql_database" "sql_db" { + name = "go-mssqldb" + + server_name = azurerm_sql_server.sql_server.name + resource_group_name = azurerm_sql_server.sql_server.resource_group_name + location = azurerm_sql_server.sql_server.location + + requested_service_objective_name = "S0" +} + +# +# Create a service principal that will be granted access to the database, +# representing an application login to the database. +# +resource "azuread_application" "app" { + name = "${local.regular_name_prefix}-app" +} + +resource "azuread_service_principal" "app_sp" { + application_id = azuread_application.app.application_id + app_role_assignment_required = false +} + +resource "random_password" "app_sp_password" { + length = 32 + special = true +} + +resource "azuread_service_principal_password" "app_sp" { + service_principal_id = azuread_service_principal.app_sp.id + value = random_password.app_sp_password.result + end_date_relative = "8760h" +} + + +# +# Create a user-assigned identity that we will add to the VM in addition to the +# system-assigned identity. +# +resource "azurerm_user_assigned_identity" "vm_user_id" { + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + + name = "${local.restricted_name_prefix}-user-id" +} + +# +# Create an Azure VM for testing managed identity authentication. +# +# To support the Azure VM, we need a virtual network, a subnet, the public IP, the network +# security group, and the network interface. The network security group allows incoming SSH +# from the anywhere on the internet. +# +resource "azurerm_virtual_network" "vm_vnet" { + name = "${local.regular_name_prefix}-vnet" + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + address_space = ["10.0.0.0/16"] +} + +resource "azurerm_subnet" "vm_subnet" { + name = "${local.regular_name_prefix}-vm-sn" + resource_group_name = azurerm_resource_group.rg.name + virtual_network_name = azurerm_virtual_network.vm_vnet.name + address_prefixes = ["10.0.2.0/24"] +} + +resource "azurerm_public_ip" "vm_ip" { + name = "${local.regular_name_prefix}-vm-ip" + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + allocation_method = "Dynamic" + idle_timeout_in_minutes = 30 +} + +resource "azurerm_network_security_group" "vm_nsg" { + name = "${local.regular_name_prefix}-vm-nsg" + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + + security_rule { + name = "SSH" + priority = 1001 + direction = "Inbound" + access = "Allow" + protocol = "Tcp" + source_port_range = "*" + destination_port_range = "22" + source_address_prefix = "*" + destination_address_prefix = "*" + } +} + +resource "azurerm_network_interface" "vm_nic" { + name = "${local.regular_name_prefix}-vm-nic" + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + + ip_configuration { + name = "${local.regular_name_prefix}-vm-nic-config" + subnet_id = azurerm_subnet.vm_subnet.id + private_ip_address_allocation = "Dynamic" + public_ip_address_id = azurerm_public_ip.vm_ip.id + } +} + +resource "azurerm_network_interface_security_group_association" "vm_nic_nsg" { + network_interface_id = azurerm_network_interface.vm_nic.id + network_security_group_id = azurerm_network_security_group.vm_nsg.id +} + +# +# Given the networking setup, now create the Azure VM +# +resource "azurerm_virtual_machine" "vm" { + name = "${local.regular_name_prefix}-vm" + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + network_interface_ids = [azurerm_network_interface.vm_nic.id] + vm_size = "Standard_B1s" + + storage_os_disk { + name = "${local.regular_name_prefix}-vm-os" + caching = "ReadWrite" + create_option = "FromImage" + managed_disk_type = "Standard_LRS" + } + + storage_image_reference { + publisher = "Canonical" + offer = "UbuntuServer" + sku = "18.04-LTS" + version = "latest" + } + + os_profile { + computer_name = "${local.regular_name_prefix}-vm" + admin_username = var.vm_admin_name + } + + os_profile_linux_config { + disable_password_authentication = true + ssh_keys { + path = "/home/${var.vm_admin_name}/.ssh/authorized_keys" + key_data = local.public_key_openssh + } + } + + # Configure the VM with both SystemAssigned and a UserAssigned identity + identity { + type = "SystemAssigned, UserAssigned" + identity_ids = [azurerm_user_assigned_identity.vm_user_id.id] + } +} + +# Retrieve the application ID corresponding to the service principal ID assigned to the VM. +data "azuread_service_principal" "vm_sp" { + object_id = azurerm_virtual_machine.vm.identity.0.principal_id +} + +# Wait for public IP to be assigned after VM is created so we can report it in the outputs. +data "azurerm_public_ip" "vm_ip" { + name = azurerm_public_ip.vm_ip.name + resource_group_name = azurerm_virtual_machine.vm.resource_group_name +} + +# +# After provisioning or refreshing, Terraform will populate these outputs. +# These capture the necessary pieces of information to access the new infrastructure. +# +output "tenant_id" { + description = "Azure tenant ID" + value = data.azurerm_client_config.current.tenant_id +} + +output "subscription_id" { + description = "Azure subscription ID" + value = data.azurerm_client_config.current.subscription_id +} + +output "sql_server_name" { + description = "Azure SQL server name" + value = azurerm_sql_server.sql_server.name +} + +output "sql_server_fqdn" { + description = "Azure SQL server domain name" + value = azurerm_sql_server.sql_server.fully_qualified_domain_name +} + +output "sql_ad_admin_user" { + description = "Azure SQL administrator name (AD authentication)" + value = azuread_user.sql_ad_admin.user_principal_name +} + +output "sql_ad_admin_password" { + description = "Azure SQL administrator password (AD authentication)" + value = random_password.sql_ad_admin_sp_password.result + sensitive = true +} + +output "sql_admin_user" { + description = "Azure SQL administrator name (SQL server authentication)" + value = azurerm_sql_server.sql_server.administrator_login +} + +output "sql_admin_password" { + description = "Azure SQL administrator password (SQL server authentication)" + value = random_password.sql_admin_password.result + sensitive = true +} + +output "sql_database_name" { + description = "Azure SQL database name" + value = azurerm_sql_database.sql_db.name +} + +output "vm_name" { + description = "Azure virtual machine name" + value = azurerm_virtual_machine.vm.name +} + +output "vm_client_id" { + description = "Azure VM system-assigned identity client ID" + value = data.azuread_service_principal.vm_sp.application_id +} + +output "vm_principal_id" { + description = "Azure VM system-assigned identity principal ID" + value = azurerm_virtual_machine.vm.identity.0.principal_id +} + +output "vm_ip_address" { + description = "Azure virtual machine public IP" + value = data.azurerm_public_ip.vm_ip.ip_address +} + +output "vm_admin_name" { + description = "Azure virtual machine admin user name" + value = var.vm_admin_name +} + +output "vm_user_ssh_private_key" { + description = "Azure virtual machine admin user private SSH key" + value = local.private_key_pem + sensitive = true +} + +output "vm_user_ssh_openssh_key" { + description = "Azure virtual machine admin user SSH public key" + value = local.public_key_openssh + sensitive = true +} + +output "app_sp_client_id" { + description = "Service principal client ID for application user" + value = azuread_application.app.application_id +} + +output "app_name" { + description = "Service principal name for application user" + value = azuread_application.app.name +} + +output "app_sp_client_secret" { + description = "Service principal client secret for application user" + value = random_password.app_sp_password.result + sensitive = true +} + +output "user_assigned_identity_name" { + description = "User-assigned identity for the Azure virtual machine" + value = azurerm_user_assigned_identity.vm_user_id.name +} + +output "user_assigned_identity_client_id" { + description = "User-assigned identity client ID" + value = azurerm_user_assigned_identity.vm_user_id.client_id +} diff --git a/examples/simple/simple.go b/examples/simple/simple.go index 67f88aa4..dcdc8472 100644 --- a/examples/simple/simple.go +++ b/examples/simple/simple.go @@ -5,12 +5,16 @@ import ( "flag" "fmt" "log" + "net/url" + "os" - _ "github.com/denisenkom/go-mssqldb" + "github.com/denisenkom/go-mssqldb/azuread" ) var ( + database = flag.String("database", "", "the database name") debug = flag.Bool("debug", false, "enable debugging") + dsn = flag.String("dsn", os.Getenv("SQLSERVER_DSN"), "complete SQL DSN") password = flag.String("password", "", "the database password") port *int = flag.Int("port", 1433, "the database port") server = flag.String("server", "", "the database server") @@ -20,24 +24,35 @@ var ( func main() { flag.Parse() - if *debug { - fmt.Printf(" password:%s\n", *password) - fmt.Printf(" port:%d\n", *port) - fmt.Printf(" server:%s\n", *server) - fmt.Printf(" user:%s\n", *user) + var connString string + + if *dsn == "" { + if *debug { + fmt.Printf(" server: %s\n", *server) + fmt.Printf(" port: %d\n", *port) + fmt.Printf(" user: %s\n", *user) + fmt.Printf(" password: %s\n", *password) + fmt.Printf(" database: %s\n", *database) + } + + connString = fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s&encrypt=true", + url.QueryEscape(*user), url.QueryEscape(*password), + url.QueryEscape(*server), *port, url.QueryEscape(*database)) + } else { + connString = *dsn } - connString := fmt.Sprintf("server=%s;user id=%s;password=%s;port=%d", *server, *user, *password, *port) if *debug { - fmt.Printf(" connString:%s\n", connString) + fmt.Printf(" dsn: %s\n", connString) } - conn, err := sql.Open("mssql", connString) + + conn, err := sql.Open(azuread.DriverName, connString) if err != nil { log.Fatal("Open connection failed:", err.Error()) } defer conn.Close() - stmt, err := conn.Prepare("select 1, 'abc'") + stmt, err := conn.Prepare("select 1, 'abc', suser_name()") if err != nil { log.Fatal("Prepare failed:", err.Error()) } @@ -46,12 +61,14 @@ func main() { row := stmt.QueryRow() var somenumber int64 var somechars string - err = row.Scan(&somenumber, &somechars) + var someuser string + err = row.Scan(&somenumber, &somechars, &someuser) if err != nil { log.Fatal("Scan failed:", err.Error()) } - fmt.Printf("somenumber:%d\n", somenumber) - fmt.Printf("somechars:%s\n", somechars) + fmt.Printf("number: %d\n", somenumber) + fmt.Printf("chars: %s\n", somechars) + fmt.Printf("user: %s\n", someuser) fmt.Printf("bye\n") } diff --git a/examples/tvp/tvp.go b/examples/tvp/tvp.go index a07bb652..eae614ef 100644 --- a/examples/tvp/tvp.go +++ b/examples/tvp/tvp.go @@ -1,3 +1,5 @@ +// +build go1.9 + package main import (