Skip to content

Commit

Permalink
Merge pull request #52 from beandrad/ia.31-stringsplit-user
Browse files Browse the repository at this point in the history
Fix concurrency issue on user create/update
  • Loading branch information
magne authored Dec 16, 2022
2 parents c0e0c1d + bef05c8 commit 8606b88
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 11 deletions.
97 changes: 92 additions & 5 deletions mssql/resource_user_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package mssql

import (
"fmt"
"os"
"testing"
"fmt"
"os"
"testing"

"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
"github.com/hashicorp/terraform-plugin-sdk/v2/terraform"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
"github.com/hashicorp/terraform-plugin-sdk/v2/terraform"
)

func TestAccUser_Local_Instance(t *testing.T) {
Expand Down Expand Up @@ -44,6 +44,21 @@ func TestAccUser_Local_Instance(t *testing.T) {
})
}

func TestAccMultipleUsers_Local_Instance(t *testing.T) {
resource.Test(t, resource.TestCase{
PreCheck: func() { testAccPreCheck(t) },
IsUnitTest: runLocalAccTests,
ProviderFactories: testAccProviders,
CheckDestroy: func(state *terraform.State) error { return testAccCheckUserDestroy(state) },
Steps: []resource.TestStep{
{
Config: testAccCheckMultipleUsers(t, "instance", "login", map[string]interface{}{"username": "instance", "login_name": "user_instance", "login_password": "valueIsH8kd$¡", "roles": "[\"db_owner\"]"}, 4),
Check: resource.ComposeTestCheckFunc(getMultipleUsersExistAccCheck(4)...),
},
},
})
}

func TestAccUser_Azure_Instance(t *testing.T) {
resource.Test(t, resource.TestCase{
PreCheck: func() { testAccPreCheck(t) },
Expand Down Expand Up @@ -478,6 +493,49 @@ func testAccCheckUser(t *testing.T, name string, login string, data map[string]i
return res
}

func testAccCheckMultipleUsers(t *testing.T, name string, login string, data map[string]interface{}, count int) string {
text := `{{ if .login_name }}
resource "mssql_login" "{{ .name }}" {
count = {{ .count }}
server {
host = "{{ .host }}"
{{if eq .login "fedauth"}}azuread_default_chain_auth {}{{ else if eq .login "msi"}}azuread_managed_identity_auth {}{{ else if eq .login "azure" }}azure_login {}{{ else }}login {}{{ end }}
}
login_name = "{{ .login_name }}-${count.index}"
password = "{{ .login_password }}"
}
{{ end }}
resource "mssql_user" "{{ .name }}" {
count = {{ .count }}
server {
host = "{{ .host }}"
{{if eq .login "fedauth"}}azuread_default_chain_auth {}{{ else if eq .login "msi"}}azuread_managed_identity_auth {}{{ else if eq .login "azure" }}azure_login {}{{ else }}login {}{{ end }}
}
{{ with .database }}database = "{{ . }}"{{ end }}
username = "{{ .username }}-${count.index}"
{{ with .password }}password = "{{ . }}"{{ end }}
{{ with .login_name }}login_name = "{{ . }}-${count.index}"{{ end }}
{{ with .default_schema }}default_schema = "{{ . }}"{{ end }}
{{ with .default_language }}default_language = "{{ . }}"{{ end }}
{{ with .roles }}roles = {{ . }}{{ end }}
}`
data["name"] = name
data["login"] = login
data["count"] = count
if login == "fedauth" || login == "msi" || login == "azure" {
data["host"] = os.Getenv("TF_ACC_SQL_SERVER")
} else if login == "login" {
data["host"] = "localhost"
} else {
t.Fatalf("login expected to be one of 'login', 'azure', 'msi', 'fedauth', got %s", login)
}
res, err := templateToString(name, text, data)
if err != nil {
t.Fatalf("%s", err)
}
return res
}

func testAccCheckUserDestroy(state *terraform.State) error {
for _, rs := range state.RootModule().Resources {
if rs.Type != "mssql_user" {
Expand Down Expand Up @@ -630,3 +688,32 @@ func testAccCheckExternalUserWorks(resource string, tenantId, clientId, clientSe
return nil
}
}

func getMultipleUsersExistAccCheck(count int) []resource.TestCheckFunc {
checkFuncs := []resource.TestCheckFunc{}
for i := 0; i < count; i++ {
checkFuncs = append(checkFuncs, []resource.TestCheckFunc{
testAccCheckUserExists(fmt.Sprintf("mssql_user.instance.%v", i)),
testAccCheckDatabaseUserWorks(fmt.Sprintf("mssql_user.instance.%v", i), fmt.Sprintf("user_instance-%v", i), "valueIsH8kd$¡"),
resource.TestCheckResourceAttr(fmt.Sprintf("mssql_user.instance.%v", i), "database", "master"),
resource.TestCheckResourceAttr(fmt.Sprintf("mssql_user.instance.%v", i), "username", fmt.Sprintf("instance-%v", i)),
resource.TestCheckResourceAttr(fmt.Sprintf("mssql_user.instance.%v", i), "login_name", fmt.Sprintf("user_instance-%v", i)),
resource.TestCheckResourceAttr(fmt.Sprintf("mssql_user.instance.%v", i), "authentication_type", "INSTANCE"),
resource.TestCheckResourceAttr(fmt.Sprintf("mssql_user.instance.%v", i), "default_schema", "dbo"),
resource.TestCheckResourceAttr(fmt.Sprintf("mssql_user.instance.%v", i), "default_language", ""),
resource.TestCheckResourceAttr(fmt.Sprintf("mssql_user.instance.%v", i), "roles.#", "1"),
resource.TestCheckResourceAttr(fmt.Sprintf("mssql_user.instance.%v", i), "roles.0", "db_owner"),
resource.TestCheckResourceAttr(fmt.Sprintf("mssql_user.instance.%v", i), "server.#", "1"),
resource.TestCheckResourceAttr(fmt.Sprintf("mssql_user.instance.%v", i), "server.0.host", "localhost"),
resource.TestCheckResourceAttr(fmt.Sprintf("mssql_user.instance.%v", i), "server.0.port", "1433"),
resource.TestCheckResourceAttr(fmt.Sprintf("mssql_user.instance.%v", i), "server.0.login.#", "1"),
resource.TestCheckResourceAttr(fmt.Sprintf("mssql_user.instance.%v", i), "server.0.login.0.username", os.Getenv("MSSQL_USERNAME")),
resource.TestCheckResourceAttr(fmt.Sprintf("mssql_user.instance.%v", i), "server.0.login.0.password", os.Getenv("MSSQL_PASSWORD")),
resource.TestCheckResourceAttr(fmt.Sprintf("mssql_user.instance.%v", i), "server.0.azure_login.#", "0"),
resource.TestCheckResourceAttrSet(fmt.Sprintf("mssql_user.instance.%v", i), "principal_id"),
resource.TestCheckNoResourceAttr(fmt.Sprintf("mssql_user.instance.%v", i), "password"),
}...,
)
}
return checkFuncs
}
22 changes: 16 additions & 6 deletions sql/user.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package sql

import (
"context"
"database/sql"
"strings"
"github.com/betr-io/terraform-provider-mssql/mssql/model"
"context"
"database/sql"
"github.com/betr-io/terraform-provider-mssql/mssql/model"
"strings"
)

func (c *Connector) GetUser(ctx context.Context, database, username string) (*model.User, error) {
Expand Down Expand Up @@ -120,7 +120,10 @@ func (c *Connector) CreateUser(ctx context.Context, database string, user *model
'DEFAULT_LANGUAGE = ' + Coalesce(QuoteName(@language), 'NONE')
END
END
IF exists (select compatibility_level FROM sys.databases where name = db_name() and compatibility_level < 130)
BEGIN TRANSACTION;
EXEC sp_getapplock @Resource = 'create_func', @LockMode = 'Exclusive';
IF exists (select compatibility_level FROM sys.databases where name = db_name() and compatibility_level < 130) AND objectproperty(object_id('String_Split'), 'isProcedure') IS NULL
BEGIN
DECLARE @sql NVARCHAR(MAX);
SET @sql = N'Create FUNCTION [dbo].[String_Split]
Expand All @@ -146,6 +149,8 @@ func (c *Connector) CreateUser(ctx context.Context, database string, user *model
)';
EXEC sp_executesql @sql;
END
EXEC sp_releaseapplock @Resource = 'create_func';
COMMIT TRANSACTION;
SET @stmt = @stmt + '; ' +
'DECLARE role_cur CURSOR FOR SELECT name FROM ' + QuoteName(@database) + '.[sys].[database_principals] WHERE type = ''R'' AND name != ''public'' AND name COLLATE SQL_Latin1_General_CP1_CI_AS IN (SELECT value FROM String_Split(' + QuoteName(@roles, '''') + ', '',''));' +
'DECLARE @role nvarchar(max);' +
Expand Down Expand Up @@ -194,7 +199,10 @@ func (c *Connector) UpdateUser(ctx context.Context, database string, user *model
BEGIN
SET @stmt = @stmt + ', DEFAULT_LANGUAGE = ' + Coalesce(QuoteName(@language), 'NONE')
END
IF exists (select compatibility_level FROM sys.databases where name = db_name() and compatibility_level < 130)
BEGIN TRANSACTION;
EXEC sp_getapplock @Resource = 'create_func', @LockMode = 'Exclusive';
IF exists (select compatibility_level FROM sys.databases where name = db_name() and compatibility_level < 130) AND objectproperty(object_id('String_Split'), 'isProcedure') IS NULL
BEGIN
DECLARE @sql NVARCHAR(MAX);
SET @sql = N'Create FUNCTION [dbo].[String_Split]
Expand All @@ -220,6 +228,8 @@ func (c *Connector) UpdateUser(ctx context.Context, database string, user *model
)';
EXEC sp_executesql @sql;
END
EXEC sp_releaseapplock @Resource = 'create_func';
COMMIT TRANSACTION;
SET @stmt = @stmt + '; ' +
'DECLARE @sql nvarchar(max);' +
'DECLARE @role nvarchar(max);' +
Expand Down

0 comments on commit 8606b88

Please sign in to comment.