Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Detect changes in lists and sets #3147

Merged
merged 15 commits into from
Nov 5, 2024
83 changes: 83 additions & 0 deletions pkg/acceptance/bettertestspoc/assert/commons.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@ package assert
import (
"errors"
"fmt"
"slices"
"strconv"
"strings"
"testing"

"golang.org/x/exp/maps"

"github.com/hashicorp/terraform-plugin-testing/helper/resource"
"github.com/hashicorp/terraform-plugin-testing/terraform"
)
Expand Down Expand Up @@ -101,3 +106,81 @@ func AssertThatObject(t *testing.T, objectAssert InPlaceAssertionVerifier) {
t.Helper()
objectAssert.VerifyAll(t)
}

func ContainsExactlyInAnyOrder(resourceKey string, attributePath string, expectedItems []map[string]string) resource.TestCheckFunc {
return func(state *terraform.State) error {
var actualItems []map[string]string
var resourceValue *terraform.ResourceState

if value, ok := state.RootModule().Resources[resourceKey]; ok {
resourceValue = value
} else {
return fmt.Errorf("resource %s not found", resourceKey)
}

// Allocate space for actualItems and assert length
for attrKey, attrValue := range resourceValue.Primary.Attributes {
if strings.HasPrefix(attrKey, attributePath) {
attr := strings.TrimPrefix(attrKey, attributePath+".")

if attr == "#" {
attrValueLen, err := strconv.Atoi(attrValue)
if err != nil {
return fmt.Errorf("failed to convert length of the attribute %s: %w", attrKey, err)
}
if len(expectedItems) != attrValueLen {
return fmt.Errorf("expected to find %d items in %s, but found %d", len(expectedItems), attributePath, attrValueLen)
}

actualItems = make([]map[string]string, attrValueLen)
for i := range actualItems {
actualItems[i] = make(map[string]string)
}
}
}
}

// Gather all actual items
for attrKey, attrValue := range resourceValue.Primary.Attributes {
if strings.HasPrefix(attrKey, attributePath) {
attr := strings.TrimPrefix(attrKey, attributePath+".")

if strings.HasSuffix(attr, "%") || strings.HasSuffix(attr, "#") {
continue
}

attrParts := strings.SplitN(attr, ".", 2)
index, indexErr := strconv.Atoi(attrParts[0])
isIndex := indexErr == nil

if len(attrParts) > 1 && isIndex {
itemKey := attrParts[1]
actualItems[index][itemKey] = attrValue
}
}
}

errs := make([]error, 0)
for _, actualItem := range actualItems {
found := false
if slices.ContainsFunc(expectedItems, func(expected map[string]string) bool { return maps.Equal(expected, actualItem) }) {
found = true
}
sfc-gh-jmichalak marked this conversation as resolved.
Show resolved Hide resolved
if !found {
errs = append(errs, fmt.Errorf("unexpected item found: %s", actualItem))
}
}

for _, expectedItem := range expectedItems {
found := false
if slices.ContainsFunc(actualItems, func(actual map[string]string) bool { return maps.Equal(actual, expectedItem) }) {
found = true
}
if !found {
errs = append(errs, fmt.Errorf("expected item to be found, but it wasn't: %s", expectedItem))
}
}

return errors.Join(errs...)
}
}
16 changes: 8 additions & 8 deletions pkg/acceptance/helpers/database_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,6 @@ func (c *DatabaseClient) CreateDatabaseWithOptions(t *testing.T, id sdk.AccountO
return database, c.DropDatabaseFunc(t, id)
}

func (c *DatabaseClient) Alter(t *testing.T, id sdk.AccountObjectIdentifier, opts *sdk.AlterDatabaseOptions) {
t.Helper()
ctx := context.Background()

err := c.client().Alter(ctx, id, opts)
require.NoError(t, err)
}

func (c *DatabaseClient) DropDatabaseFunc(t *testing.T, id sdk.AccountObjectIdentifier) func() {
t.Helper()
return func() { require.NoError(t, c.DropDatabase(t, id)) }
Expand Down Expand Up @@ -192,3 +184,11 @@ func (c *DatabaseClient) ShowAllReplicationDatabases(t *testing.T) ([]sdk.Replic

return c.context.client.ReplicationFunctions.ShowReplicationDatabases(ctx, nil)
}

func (c *DatabaseClient) Alter(t *testing.T, id sdk.AccountObjectIdentifier, opts *sdk.AlterDatabaseOptions) {
t.Helper()
ctx := context.Background()

err := c.client().Alter(ctx, id, opts)
require.NoError(t, err)
}
16 changes: 8 additions & 8 deletions pkg/acceptance/helpers/schema_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,6 @@ func (c *SchemaClient) CreateSchemaWithOpts(t *testing.T, id sdk.DatabaseObjectI
return schema, c.DropSchemaFunc(t, id)
}

func (c *SchemaClient) Alter(t *testing.T, id sdk.DatabaseObjectIdentifier, opts *sdk.AlterSchemaOptions) {
t.Helper()
ctx := context.Background()

err := c.client().Alter(ctx, id, opts)
require.NoError(t, err)
}

func (c *SchemaClient) DropSchemaFunc(t *testing.T, id sdk.DatabaseObjectIdentifier) func() {
t.Helper()
ctx := context.Background()
Expand Down Expand Up @@ -112,3 +104,11 @@ func (c *SchemaClient) ShowWithOptions(t *testing.T, opts *sdk.ShowSchemaOptions
require.NoError(t, err)
return schemas
}

func (c *SchemaClient) Alter(t *testing.T, id sdk.DatabaseObjectIdentifier, opts *sdk.AlterSchemaOptions) {
t.Helper()
ctx := context.Background()

err := c.client().Alter(ctx, id, opts)
require.NoError(t, err)
}
10 changes: 9 additions & 1 deletion pkg/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"strings"
"time"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/testenvs"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/datasources"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider/docs"
Expand Down Expand Up @@ -420,7 +422,7 @@ func Provider() *schema.Provider {
}

func getResources() map[string]*schema.Resource {
return map[string]*schema.Resource{
resourceList := map[string]*schema.Resource{
"snowflake_account": resources.Account(),
"snowflake_account_role": resources.AccountRole(),
"snowflake_account_password_policy_attachment": resources.AccountPasswordPolicyAttachment(),
Expand Down Expand Up @@ -501,6 +503,12 @@ func getResources() map[string]*schema.Resource {
"snowflake_view": resources.View(),
"snowflake_warehouse": resources.Warehouse(),
}

if os.Getenv(string(testenvs.EnableObjectRenamingTest)) != "" {
resourceList["snowflake_object_renaming"] = resources.ObjectRenamingListsAndSets()
}

return resourceList
}

func getDataSources() map[string]*schema.Resource {
Expand Down
Loading
Loading