diff --git a/pkg/acceptance/helpers/account_client.go b/pkg/acceptance/helpers/account_client.go new file mode 100644 index 0000000000..bf78e6a289 --- /dev/null +++ b/pkg/acceptance/helpers/account_client.go @@ -0,0 +1,44 @@ +package helpers + +import ( + "context" + "testing" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/stretchr/testify/require" +) + +type AccountClient struct { + context *TestClientContext +} + +func NewAccountClient(context *TestClientContext) *AccountClient { + return &AccountClient{ + context: context, + } +} + +func (c *AccountClient) client() sdk.Accounts { + return c.context.client.Accounts +} + +// GetAccountIdentifier gets the account identifier from Snowflake API, by fetching the account locator +// and by filtering the list of accounts in replication accounts by it (because there is no direct way to get). +func (c *AccountClient) GetAccountIdentifier(t *testing.T) sdk.AccountIdentifier { + t.Helper() + ctx := context.Background() + + currentAccountLocator, err := c.context.client.ContextFunctions.CurrentAccount(ctx) + require.NoError(t, err) + + replicationAccounts, err := c.context.client.ReplicationFunctions.ShowReplicationAccounts(ctx) + require.NoError(t, err) + + for _, replicationAccount := range replicationAccounts { + if replicationAccount.AccountLocator == currentAccountLocator { + return sdk.NewAccountIdentifier(replicationAccount.OrganizationName, replicationAccount.AccountName) + } + } + t.Fatal("could not find the account identifier for the locator") + return sdk.AccountIdentifier{} +} diff --git a/pkg/acceptance/helpers/api_integration_client.go b/pkg/acceptance/helpers/api_integration_client.go new file mode 100644 index 0000000000..114d2a7297 --- /dev/null +++ b/pkg/acceptance/helpers/api_integration_client.go @@ -0,0 +1,52 @@ +package helpers + +import ( + "context" + "testing" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/helpers/random" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/stretchr/testify/require" +) + +type ApiIntegrationClient struct { + context *TestClientContext +} + +func NewApiIntegrationClient(context *TestClientContext) *ApiIntegrationClient { + return &ApiIntegrationClient{ + context: context, + } +} + +func (c *ApiIntegrationClient) client() sdk.ApiIntegrations { + return c.context.client.ApiIntegrations +} + +func (c *ApiIntegrationClient) CreateApiIntegration(t *testing.T) (*sdk.ApiIntegration, func()) { + t.Helper() + ctx := context.Background() + + id := sdk.NewAccountObjectIdentifier(random.AlphanumericN(12)) + apiAllowedPrefixes := []sdk.ApiIntegrationEndpointPrefix{{Path: "https://xyz.execute-api.us-west-2.amazonaws.com/production"}} + req := sdk.NewCreateApiIntegrationRequest(id, apiAllowedPrefixes, true) + req.WithAwsApiProviderParams(sdk.NewAwsApiParamsRequest(sdk.ApiIntegrationAwsApiGateway, "arn:aws:iam::123456789012:role/hello_cloud_account_role")) + + err := c.client().Create(ctx, req) + require.NoError(t, err) + + apiIntegration, err := c.client().ShowByID(ctx, id) + require.NoError(t, err) + + return apiIntegration, c.DropApiIntegrationFunc(t, id) +} + +func (c *ApiIntegrationClient) DropApiIntegrationFunc(t *testing.T, id sdk.AccountObjectIdentifier) func() { + t.Helper() + ctx := context.Background() + + return func() { + err := c.client().Drop(ctx, sdk.NewDropApiIntegrationRequest(id).WithIfExists(sdk.Bool(true))) + require.NoError(t, err) + } +} diff --git a/pkg/acceptance/helpers/application_package_client.go b/pkg/acceptance/helpers/application_package_client.go index f5970a53cd..dfaf69589b 100644 --- a/pkg/acceptance/helpers/application_package_client.go +++ b/pkg/acceptance/helpers/application_package_client.go @@ -2,6 +2,7 @@ package helpers import ( "context" + "fmt" "testing" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/helpers/random" @@ -56,3 +57,17 @@ func (c *ApplicationPackageClient) AddApplicationPackageVersion(t *testing.T, id err := c.client().Alter(ctx, sdk.NewAlterApplicationPackageRequest(id).WithAddVersion(sdk.NewAddVersionRequest(using).WithVersionIdentifier(sdk.String(versionName)))) require.NoError(t, err) } + +func (c *ApplicationPackageClient) ShowVersions(t *testing.T, id sdk.AccountObjectIdentifier) []ApplicationPackageVersion { + t.Helper() + + var versions []ApplicationPackageVersion + err := c.context.client.QueryForTests(context.Background(), &versions, fmt.Sprintf(`SHOW VERSIONS IN APPLICATION PACKAGE %s`, id.FullyQualifiedName())) + require.NoError(t, err) + return versions +} + +type ApplicationPackageVersion struct { + Version string `json:"version"` + Patch int `json:"patch"` +} diff --git a/pkg/acceptance/helpers/database_client.go b/pkg/acceptance/helpers/database_client.go index 4f56f70c90..ed9a8d239c 100644 --- a/pkg/acceptance/helpers/database_client.go +++ b/pkg/acceptance/helpers/database_client.go @@ -36,10 +36,13 @@ func (c *DatabaseClient) CreateDatabaseWithName(t *testing.T, name string) (*sdk func (c *DatabaseClient) CreateDatabaseWithOptions(t *testing.T, id sdk.AccountObjectIdentifier, opts *sdk.CreateDatabaseOptions) (*sdk.Database, func()) { t.Helper() ctx := context.Background() + err := c.client().Create(ctx, id, opts) require.NoError(t, err) + database, err := c.client().ShowByID(ctx, id) require.NoError(t, err) + return database, c.DropDatabaseFunc(t, id) } diff --git a/pkg/acceptance/helpers/parameter_client.go b/pkg/acceptance/helpers/parameter_client.go new file mode 100644 index 0000000000..c0e62d79ed --- /dev/null +++ b/pkg/acceptance/helpers/parameter_client.go @@ -0,0 +1,40 @@ +package helpers + +import ( + "context" + "testing" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/stretchr/testify/require" +) + +type ParameterClient struct { + context *TestClientContext +} + +func NewParameterClient(context *TestClientContext) *ParameterClient { + return &ParameterClient{ + context: context, + } +} + +func (c *ParameterClient) client() sdk.Parameters { + return c.context.client.Parameters +} + +func (c *ParameterClient) UpdateAccountParameterTemporarily(t *testing.T, parameter sdk.AccountParameter, newValue string) func() { + t.Helper() + ctx := context.Background() + + param, err := c.client().ShowAccountParameter(ctx, parameter) + require.NoError(t, err) + oldValue := param.Value + + err = c.client().SetAccountParameter(ctx, parameter, newValue) + require.NoError(t, err) + + return func() { + err = c.client().SetAccountParameter(ctx, parameter, oldValue) + require.NoError(t, err) + } +} diff --git a/pkg/acceptance/helpers/role_client.go b/pkg/acceptance/helpers/role_client.go index bb4fd7c12f..ec0223a671 100644 --- a/pkg/acceptance/helpers/role_client.go +++ b/pkg/acceptance/helpers/role_client.go @@ -110,3 +110,12 @@ func (c *RoleClient) GrantOwnershipOnAccountObject(t *testing.T, roleId sdk.Acco ) require.NoError(t, err) } + +// TODO: move later to grants client +func (c *RoleClient) GrantPrivilegeOnDatabaseToShare(t *testing.T, databaseId sdk.AccountObjectIdentifier, shareId sdk.AccountObjectIdentifier) { + t.Helper() + ctx := context.Background() + + err := c.context.client.Grants.GrantPrivilegeToShare(ctx, []sdk.ObjectPrivilege{sdk.ObjectPrivilegeReferenceUsage}, &sdk.ShareGrantOn{Database: databaseId}, shareId) + require.NoError(t, err) +} diff --git a/pkg/acceptance/helpers/row_access_policy_client.go b/pkg/acceptance/helpers/row_access_policy_client.go new file mode 100644 index 0000000000..9e68989c6d --- /dev/null +++ b/pkg/acceptance/helpers/row_access_policy_client.go @@ -0,0 +1,84 @@ +package helpers + +import ( + "context" + "database/sql" + "fmt" + "testing" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/helpers/random" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/stretchr/testify/require" +) + +type RowAccessPolicyClient struct { + context *TestClientContext +} + +func NewRowAccessPolicyClient(context *TestClientContext) *RowAccessPolicyClient { + return &RowAccessPolicyClient{ + context: context, + } +} + +func (c *RowAccessPolicyClient) client() sdk.RowAccessPolicies { + return c.context.client.RowAccessPolicies +} + +func (c *RowAccessPolicyClient) CreateRowAccessPolicy(t *testing.T) (*sdk.RowAccessPolicy, func()) { + t.Helper() + ctx := context.Background() + + id := c.context.newSchemaObjectIdentifier(random.AlphanumericN(12)) + arg := sdk.NewCreateRowAccessPolicyArgsRequest("A", sdk.DataTypeNumber) + body := "true" + createRequest := sdk.NewCreateRowAccessPolicyRequest(id, []sdk.CreateRowAccessPolicyArgsRequest{*arg}, body) + + err := c.client().Create(ctx, createRequest) + require.NoError(t, err) + + rowAccessPolicy, err := c.client().ShowByID(ctx, id) + require.NoError(t, err) + + return rowAccessPolicy, c.DropRowAccessPolicyFunc(t, id) +} + +func (c *RowAccessPolicyClient) DropRowAccessPolicyFunc(t *testing.T, id sdk.SchemaObjectIdentifier) func() { + t.Helper() + ctx := context.Background() + + return func() { + err := c.client().Drop(ctx, sdk.NewDropRowAccessPolicyRequest(id).WithIfExists(sdk.Bool(true))) + require.NoError(t, err) + } +} + +// GetRowAccessPolicyFor is based on https://docs.snowflake.com/en/user-guide/security-row-intro#obtain-database-objects-with-a-row-access-policy. +// TODO: extract getting row access policies as resource (like getting tag in system functions) +func (c *RowAccessPolicyClient) GetRowAccessPolicyFor(t *testing.T, id sdk.SchemaObjectIdentifier, objectType sdk.ObjectType) (*PolicyReference, error) { + t.Helper() + ctx := context.Background() + + s := &PolicyReference{} + policyReferencesId := sdk.NewSchemaObjectIdentifier(id.DatabaseName(), "INFORMATION_SCHEMA", "POLICY_REFERENCES") + err := c.context.client.QueryOneForTests(ctx, s, fmt.Sprintf(`SELECT * FROM TABLE(%s(REF_ENTITY_NAME => '%s', REF_ENTITY_DOMAIN => '%v'))`, policyReferencesId.FullyQualifiedName(), id.FullyQualifiedName(), objectType)) + + return s, err +} + +type PolicyReference struct { + PolicyDb string `db:"POLICY_DB"` + PolicySchema string `db:"POLICY_SCHEMA"` + PolicyName string `db:"POLICY_NAME"` + PolicyKind string `db:"POLICY_KIND"` + RefDatabaseName string `db:"REF_DATABASE_NAME"` + RefSchemaName string `db:"REF_SCHEMA_NAME"` + RefEntityName string `db:"REF_ENTITY_NAME"` + RefEntityDomain string `db:"REF_ENTITY_DOMAIN"` + RefColumnName sql.NullString `db:"REF_COLUMN_NAME"` + RefArgColumnNames string `db:"REF_ARG_COLUMN_NAMES"` + TagDatabase sql.NullString `db:"TAG_DATABASE"` + TagSchema sql.NullString `db:"TAG_SCHEMA"` + TagName sql.NullString `db:"TAG_NAME"` + PolicyStatus string `db:"POLICY_STATUS"` +} diff --git a/pkg/acceptance/helpers/share_client.go b/pkg/acceptance/helpers/share_client.go new file mode 100644 index 0000000000..e4dc80e9ed --- /dev/null +++ b/pkg/acceptance/helpers/share_client.go @@ -0,0 +1,70 @@ +package helpers + +import ( + "context" + "testing" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/helpers/random" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/stretchr/testify/require" +) + +type ShareClient struct { + context *TestClientContext +} + +func NewShareClient(context *TestClientContext) *ShareClient { + return &ShareClient{ + context: context, + } +} + +func (c *ShareClient) client() sdk.Shares { + return c.context.client.Shares +} + +func (c *ShareClient) CreateShare(t *testing.T) (*sdk.Share, func()) { + t.Helper() + // TODO(SNOW-1058419): Try with identifier containing dot during identifiers rework + return c.CreateShareWithName(t, random.AlphanumericN(12)) +} + +func (c *ShareClient) CreateShareWithName(t *testing.T, name string) (*sdk.Share, func()) { + t.Helper() + return c.CreateShareWithOptions(t, sdk.NewAccountObjectIdentifier(name), &sdk.CreateShareOptions{}) +} + +func (c *ShareClient) CreateShareWithOptions(t *testing.T, id sdk.AccountObjectIdentifier, opts *sdk.CreateShareOptions) (*sdk.Share, func()) { + t.Helper() + ctx := context.Background() + + err := c.client().Create(ctx, id, opts) + require.NoError(t, err) + + share, err := c.client().ShowByID(ctx, id) + require.NoError(t, err) + + return share, c.DropShareFunc(t, id) +} + +func (c *ShareClient) DropShareFunc(t *testing.T, id sdk.AccountObjectIdentifier) func() { + t.Helper() + ctx := context.Background() + + return func() { + err := c.client().Drop(ctx, id, &sdk.DropShareOptions{IfExists: sdk.Bool(true)}) + require.NoError(t, err) + } +} + +func (c *ShareClient) SetAccountOnShare(t *testing.T, accountId sdk.AccountIdentifier, shareId sdk.AccountObjectIdentifier) { + t.Helper() + ctx := context.Background() + + err := c.client().Alter(ctx, shareId, &sdk.AlterShareOptions{ + Set: &sdk.ShareSet{ + Accounts: []sdk.AccountIdentifier{accountId}, + }, + }) + require.NoError(t, err) +} diff --git a/pkg/acceptance/helpers/stage_client.go b/pkg/acceptance/helpers/stage_client.go index bda41dc200..bbad106f06 100644 --- a/pkg/acceptance/helpers/stage_client.go +++ b/pkg/acceptance/helpers/stage_client.go @@ -12,6 +12,10 @@ import ( "github.com/stretchr/testify/require" ) +const ( + nycWeatherDataURL = "s3://snowflake-workshop-lab/weather-nyc" +) + type StageClient struct { context *TestClientContext } @@ -26,11 +30,11 @@ func (c *StageClient) client() sdk.Stages { return c.context.client.Stages } -func (c *StageClient) CreateStageWithURL(t *testing.T, id sdk.SchemaObjectIdentifier, url string) (*sdk.Stage, func()) { +func (c *StageClient) CreateStageWithURL(t *testing.T, id sdk.SchemaObjectIdentifier) (*sdk.Stage, func()) { t.Helper() ctx := context.Background() err := c.client().CreateOnS3(ctx, sdk.NewCreateOnS3StageRequest(id). - WithExternalStageParams(sdk.NewExternalS3StageParamsRequest(url))) + WithExternalStageParams(sdk.NewExternalS3StageParamsRequest(nycWeatherDataURL))) require.NoError(t, err) stage, err := c.client().ShowByID(ctx, id) diff --git a/pkg/acceptance/helpers/table_client.go b/pkg/acceptance/helpers/table_client.go index 6c35953d35..eb5e9fa08d 100644 --- a/pkg/acceptance/helpers/table_client.go +++ b/pkg/acceptance/helpers/table_client.go @@ -2,7 +2,9 @@ package helpers import ( "context" + "database/sql" "errors" + "fmt" "testing" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/helpers/random" @@ -70,3 +72,63 @@ func (c *TableClient) DropTableFunc(t *testing.T, id sdk.SchemaObjectIdentifier) require.NoError(t, dropErr) } } + +// GetTableColumnsFor is based on https://docs.snowflake.com/en/sql-reference/info-schema/columns. +// TODO: extract getting table columns as resource (like getting tag in system functions) +func (c *TableClient) GetTableColumnsFor(t *testing.T, tableId sdk.SchemaObjectIdentifier) []InformationSchemaColumns { + t.Helper() + ctx := context.Background() + + var columns []InformationSchemaColumns + query := fmt.Sprintf("SELECT * FROM information_schema.columns WHERE table_schema = '%s' AND table_name = '%s' ORDER BY ordinal_position", tableId.SchemaName(), tableId.Name()) + err := c.context.client.QueryForTests(ctx, &columns, query) + require.NoError(t, err) + + return columns +} + +type InformationSchemaColumns struct { + TableCatalog string `db:"TABLE_CATALOG"` + TableSchema string `db:"TABLE_SCHEMA"` + TableName string `db:"TABLE_NAME"` + ColumnName string `db:"COLUMN_NAME"` + OrdinalPosition string `db:"ORDINAL_POSITION"` + ColumnDefault sql.NullString `db:"COLUMN_DEFAULT"` + IsNullable string `db:"IS_NULLABLE"` + DataType string `db:"DATA_TYPE"` + CharacterMaximumLength sql.NullString `db:"CHARACTER_MAXIMUM_LENGTH"` + CharacterOctetLength sql.NullString `db:"CHARACTER_OCTET_LENGTH"` + NumericPrecision sql.NullString `db:"NUMERIC_PRECISION"` + NumericPrecisionRadix sql.NullString `db:"NUMERIC_PRECISION_RADIX"` + NumericScale sql.NullString `db:"NUMERIC_SCALE"` + DatetimePrecision sql.NullString `db:"DATETIME_PRECISION"` + IntervalType sql.NullString `db:"INTERVAL_TYPE"` + IntervalPrecision sql.NullString `db:"INTERVAL_PRECISION"` + CharacterSetCatalog sql.NullString `db:"CHARACTER_SET_CATALOG"` + CharacterSetSchema sql.NullString `db:"CHARACTER_SET_SCHEMA"` + CharacterSetName sql.NullString `db:"CHARACTER_SET_NAME"` + CollationCatalog sql.NullString `db:"COLLATION_CATALOG"` + CollationSchema sql.NullString `db:"COLLATION_SCHEMA"` + CollationName sql.NullString `db:"COLLATION_NAME"` + DomainCatalog sql.NullString `db:"DOMAIN_CATALOG"` + DomainSchema sql.NullString `db:"DOMAIN_SCHEMA"` + DomainName sql.NullString `db:"DOMAIN_NAME"` + UdtCatalog sql.NullString `db:"UDT_CATALOG"` + UdtSchema sql.NullString `db:"UDT_SCHEMA"` + UdtName sql.NullString `db:"UDT_NAME"` + ScopeCatalog sql.NullString `db:"SCOPE_CATALOG"` + ScopeSchema sql.NullString `db:"SCOPE_SCHEMA"` + ScopeName sql.NullString `db:"SCOPE_NAME"` + MaximumCardinality sql.NullString `db:"MAXIMUM_CARDINALITY"` + DtdIdentifier sql.NullString `db:"DTD_IDENTIFIER"` + IsSelfReferencing string `db:"IS_SELF_REFERENCING"` + IsIdentity string `db:"IS_IDENTITY"` + IdentityGeneration sql.NullString `db:"IDENTITY_GENERATION"` + IdentityStart sql.NullString `db:"IDENTITY_START"` + IdentityIncrement sql.NullString `db:"IDENTITY_INCREMENT"` + IdentityMaximum sql.NullString `db:"IDENTITY_MAXIMUM"` + IdentityMinimum sql.NullString `db:"IDENTITY_MINIMUM"` + IdentityCycle sql.NullString `db:"IDENTITY_CYCLE"` + IdentityOrdered sql.NullString `db:"IDENTITY_ORDERED"` + Comment sql.NullString `db:"COMMENT"` +} diff --git a/pkg/acceptance/helpers/task_client.go b/pkg/acceptance/helpers/task_client.go new file mode 100644 index 0000000000..26d85ea068 --- /dev/null +++ b/pkg/acceptance/helpers/task_client.go @@ -0,0 +1,66 @@ +package helpers + +import ( + "context" + "testing" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/helpers/random" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/stretchr/testify/require" +) + +type TaskClient struct { + context *TestClientContext +} + +func NewTaskClient(context *TestClientContext) *TaskClient { + return &TaskClient{ + context: context, + } +} + +func (c *TaskClient) client() sdk.Tasks { + return c.context.client.Tasks +} + +func (c *TaskClient) defaultCreateTaskRequest(t *testing.T) *sdk.CreateTaskRequest { + t.Helper() + id := c.context.newSchemaObjectIdentifier(random.AlphanumericN(12)) + warehouseReq := sdk.NewCreateTaskWarehouseRequest().WithWarehouse(sdk.Pointer(c.context.warehouseId())) + return sdk.NewCreateTaskRequest(id, "SELECT CURRENT_TIMESTAMP").WithWarehouse(warehouseReq) +} + +func (c *TaskClient) CreateTask(t *testing.T) (*sdk.Task, func()) { + t.Helper() + return c.CreateTaskWithRequest(t, c.defaultCreateTaskRequest(t).WithSchedule(sdk.String("60 minutes"))) +} + +func (c *TaskClient) CreateTaskWithAfter(t *testing.T, taskId sdk.SchemaObjectIdentifier) (*sdk.Task, func()) { + t.Helper() + return c.CreateTaskWithRequest(t, c.defaultCreateTaskRequest(t).WithAfter([]sdk.SchemaObjectIdentifier{taskId})) +} + +func (c *TaskClient) CreateTaskWithRequest(t *testing.T, request *sdk.CreateTaskRequest) (*sdk.Task, func()) { + t.Helper() + ctx := context.Background() + + id := request.GetName() + + err := c.client().Create(ctx, request) + require.NoError(t, err) + + task, err := c.client().ShowByID(ctx, id) + require.NoError(t, err) + + return task, c.DropTaskFunc(t, id) +} + +func (c *TaskClient) DropTaskFunc(t *testing.T, id sdk.SchemaObjectIdentifier) func() { + t.Helper() + ctx := context.Background() + + return func() { + err := c.client().Drop(ctx, sdk.NewDropTaskRequest(id).WithIfExists(sdk.Bool(true))) + require.NoError(t, err) + } +} diff --git a/pkg/acceptance/helpers/test_client.go b/pkg/acceptance/helpers/test_client.go index 77a3b78f74..b7b265b975 100644 --- a/pkg/acceptance/helpers/test_client.go +++ b/pkg/acceptance/helpers/test_client.go @@ -5,7 +5,9 @@ import "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" type TestClient struct { context *TestClientContext + Account *AccountClient Alert *AlertClient + ApiIntegration *ApiIntegrationClient Application *ApplicationClient ApplicationPackage *ApplicationPackageClient Context *ContextClient @@ -16,16 +18,21 @@ type TestClient struct { FileFormat *FileFormatClient MaskingPolicy *MaskingPolicyClient NetworkPolicy *NetworkPolicyClient + Parameter *ParameterClient PasswordPolicy *PasswordPolicyClient Pipe *PipeClient ResourceMonitor *ResourceMonitorClient Role *RoleClient + RowAccessPolicy *RowAccessPolicyClient Schema *SchemaClient SessionPolicy *SessionPolicyClient + Share *ShareClient Stage *StageClient Table *TableClient Tag *TagClient + Task *TaskClient User *UserClient + View *ViewClient Warehouse *WarehouseClient } @@ -38,7 +45,9 @@ func NewTestClient(c *sdk.Client, database string, schema string, warehouse stri } return &TestClient{ context: context, + Account: NewAccountClient(context), Alert: NewAlertClient(context), + ApiIntegration: NewApiIntegrationClient(context), Application: NewApplicationClient(context), ApplicationPackage: NewApplicationPackageClient(context), Context: NewContextClient(context), @@ -49,16 +58,21 @@ func NewTestClient(c *sdk.Client, database string, schema string, warehouse stri FileFormat: NewFileFormatClient(context), MaskingPolicy: NewMaskingPolicyClient(context), NetworkPolicy: NewNetworkPolicyClient(context), + Parameter: NewParameterClient(context), PasswordPolicy: NewPasswordPolicyClient(context), Pipe: NewPipeClient(context), ResourceMonitor: NewResourceMonitorClient(context), Role: NewRoleClient(context), + RowAccessPolicy: NewRowAccessPolicyClient(context), Schema: NewSchemaClient(context), SessionPolicy: NewSessionPolicyClient(context), + Share: NewShareClient(context), Stage: NewStageClient(context), - Tag: NewTagClient(context), Table: NewTableClient(context), + Tag: NewTagClient(context), + Task: NewTaskClient(context), User: NewUserClient(context), + View: NewViewClient(context), Warehouse: NewWarehouseClient(context), } } diff --git a/pkg/acceptance/helpers/view_client.go b/pkg/acceptance/helpers/view_client.go new file mode 100644 index 0000000000..029dd8d59e --- /dev/null +++ b/pkg/acceptance/helpers/view_client.go @@ -0,0 +1,49 @@ +package helpers + +import ( + "context" + "testing" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/helpers/random" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/stretchr/testify/require" +) + +type ViewClient struct { + context *TestClientContext +} + +func NewViewClient(context *TestClientContext) *ViewClient { + return &ViewClient{ + context: context, + } +} + +func (c *ViewClient) client() sdk.Views { + return c.context.client.Views +} + +func (c *ViewClient) CreateView(t *testing.T, query string) (*sdk.View, func()) { + t.Helper() + ctx := context.Background() + + id := c.context.newSchemaObjectIdentifier(random.AlphanumericN(12)) + + err := c.client().Create(ctx, sdk.NewCreateViewRequest(id, query)) + require.NoError(t, err) + + view, err := c.client().ShowByID(ctx, id) + require.NoError(t, err) + + return view, c.DropViewFunc(t, id) +} + +func (c *ViewClient) DropViewFunc(t *testing.T, id sdk.SchemaObjectIdentifier) func() { + t.Helper() + ctx := context.Background() + + return func() { + err := c.client().Drop(ctx, sdk.NewDropViewRequest(id).WithIfExists(sdk.Bool(true))) + require.NoError(t, err) + } +} diff --git a/pkg/acceptance/testing.go b/pkg/acceptance/testing.go index 0ec163e939..cf05ca7823 100644 --- a/pkg/acceptance/testing.go +++ b/pkg/acceptance/testing.go @@ -118,34 +118,44 @@ var once sync.Once func TestAccPreCheck(t *testing.T) { // use singleton design pattern to ensure we only create these resources once + // there is no cleanup currently, sweepers take care of it once.Do(func() { ctx := context.Background() dbId := sdk.NewAccountObjectIdentifier(TestDatabaseName) - if err := atc.client.Databases.Create(ctx, dbId, &sdk.CreateDatabaseOptions{ - IfNotExists: sdk.Bool(true), - }); err != nil { + schemaId := sdk.NewDatabaseObjectIdentifier(TestDatabaseName, TestSchemaName) + warehouseId := sdk.NewAccountObjectIdentifier(TestWarehouseName) + warehouseId2 := sdk.NewAccountObjectIdentifier(TestWarehouseName2) + + if err := atc.client.Databases.Create(ctx, dbId, &sdk.CreateDatabaseOptions{IfNotExists: sdk.Bool(true)}); err != nil { t.Fatal(err) } - schemaId := sdk.NewDatabaseObjectIdentifier(TestDatabaseName, TestSchemaName) - if err := atc.client.Schemas.Create(ctx, schemaId, &sdk.CreateSchemaOptions{ - IfNotExists: sdk.Bool(true), - }); err != nil { + if err := atc.client.Schemas.Create(ctx, schemaId, &sdk.CreateSchemaOptions{IfNotExists: sdk.Bool(true)}); err != nil { t.Fatal(err) } - warehouseId := sdk.NewAccountObjectIdentifier(TestWarehouseName) - if err := atc.client.Warehouses.Create(ctx, warehouseId, &sdk.CreateWarehouseOptions{ - IfNotExists: sdk.Bool(true), - }); err != nil { + if err := atc.client.Warehouses.Create(ctx, warehouseId, &sdk.CreateWarehouseOptions{IfNotExists: sdk.Bool(true)}); err != nil { t.Fatal(err) } - warehouseId2 := sdk.NewAccountObjectIdentifier(TestWarehouseName2) - if err := atc.client.Warehouses.Create(ctx, warehouseId2, &sdk.CreateWarehouseOptions{ - IfNotExists: sdk.Bool(true), - }); err != nil { + if err := atc.client.Warehouses.Create(ctx, warehouseId2, &sdk.CreateWarehouseOptions{IfNotExists: sdk.Bool(true)}); err != nil { + t.Fatal(err) + } + + if err := atc.secondaryClient.Databases.Create(ctx, dbId, &sdk.CreateDatabaseOptions{IfNotExists: sdk.Bool(true)}); err != nil { + t.Fatal(err) + } + + if err := atc.secondaryClient.Schemas.Create(ctx, schemaId, &sdk.CreateSchemaOptions{IfNotExists: sdk.Bool(true)}); err != nil { + t.Fatal(err) + } + + if err := atc.secondaryClient.Warehouses.Create(ctx, warehouseId, &sdk.CreateWarehouseOptions{IfNotExists: sdk.Bool(true)}); err != nil { + t.Fatal(err) + } + + if err := atc.secondaryClient.Warehouses.Create(ctx, warehouseId2, &sdk.CreateWarehouseOptions{IfNotExists: sdk.Bool(true)}); err != nil { t.Fatal(err) } }) diff --git a/pkg/resources/database_acceptance_test.go b/pkg/resources/database_acceptance_test.go index 4c7651c7e7..3de32dc775 100644 --- a/pkg/resources/database_acceptance_test.go +++ b/pkg/resources/database_acceptance_test.go @@ -196,8 +196,6 @@ func TestAcc_Database_DefaultDataRetentionTime(t *testing.T) { return vars } - client := acc.Client(t) - resource.Test(t, resource.TestCase{ ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, PreCheck: func() { acc.TestAccPreCheck(t) }, @@ -207,7 +205,10 @@ func TestAcc_Database_DefaultDataRetentionTime(t *testing.T) { CheckDestroy: acc.CheckDestroy(t, resources.Database), Steps: []resource.TestStep{ { - PreConfig: updateAccountParameter(t, client, sdk.AccountParameterDataRetentionTimeInDays, true, "5"), + PreConfig: func() { + revertParameter := acc.TestClient().Parameter.UpdateAccountParameterTemporarily(t, sdk.AccountParameterDataRetentionTimeInDays, "5") + t.Cleanup(revertParameter) + }, ConfigDirectory: acc.ConfigurationDirectory("TestAcc_Database_DefaultDataRetentionTime/WithoutDataRetentionSet"), ConfigVariables: configVariablesWithoutDatabaseDataRetentionTime(), Check: resource.ComposeTestCheckFunc( @@ -216,7 +217,9 @@ func TestAcc_Database_DefaultDataRetentionTime(t *testing.T) { ), }, { - PreConfig: updateAccountParameter(t, client, sdk.AccountParameterDataRetentionTimeInDays, false, "10"), + PreConfig: func() { + _ = acc.TestClient().Parameter.UpdateAccountParameterTemporarily(t, sdk.AccountParameterDataRetentionTimeInDays, "10") + }, ConfigDirectory: acc.ConfigurationDirectory("TestAcc_Database_DefaultDataRetentionTime/WithoutDataRetentionSet"), ConfigVariables: configVariablesWithoutDatabaseDataRetentionTime(), Check: resource.ComposeTestCheckFunc( @@ -285,8 +288,6 @@ func TestAcc_Database_DefaultDataRetentionTime_SetOutsideOfTerraform(t *testing. return vars } - client := acc.Client(t) - resource.Test(t, resource.TestCase{ ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, PreCheck: func() { acc.TestAccPreCheck(t) }, @@ -296,7 +297,10 @@ func TestAcc_Database_DefaultDataRetentionTime_SetOutsideOfTerraform(t *testing. CheckDestroy: acc.CheckDestroy(t, resources.Database), Steps: []resource.TestStep{ { - PreConfig: updateAccountParameter(t, client, sdk.AccountParameterDataRetentionTimeInDays, true, "5"), + PreConfig: func() { + revertParameter := acc.TestClient().Parameter.UpdateAccountParameterTemporarily(t, sdk.AccountParameterDataRetentionTimeInDays, "5") + t.Cleanup(revertParameter) + }, ConfigDirectory: acc.ConfigurationDirectory("TestAcc_Database_DefaultDataRetentionTime/WithoutDataRetentionSet"), ConfigVariables: configVariablesWithoutDatabaseDataRetentionTime(), Check: resource.ComposeTestCheckFunc( @@ -314,7 +318,9 @@ func TestAcc_Database_DefaultDataRetentionTime_SetOutsideOfTerraform(t *testing. ), }, { - PreConfig: updateAccountParameter(t, client, sdk.AccountParameterDataRetentionTimeInDays, false, "10"), + PreConfig: func() { + _ = acc.TestClient().Parameter.UpdateAccountParameterTemporarily(t, sdk.AccountParameterDataRetentionTimeInDays, "10") + }, ConfigDirectory: acc.ConfigurationDirectory("TestAcc_Database_DefaultDataRetentionTime/WithDataRetentionSet"), ConfigVariables: configVariablesWithDatabaseDataRetentionTime(3), Check: resource.ComposeTestCheckFunc( diff --git a/pkg/resources/grant_privileges_to_account_role_acceptance_test.go b/pkg/resources/grant_privileges_to_account_role_acceptance_test.go index 7ac5e858f8..e14ea0b905 100644 --- a/pkg/resources/grant_privileges_to_account_role_acceptance_test.go +++ b/pkg/resources/grant_privileges_to_account_role_acceptance_test.go @@ -2,7 +2,6 @@ package resources_test import ( "context" - "errors" "fmt" "regexp" "strings" @@ -17,7 +16,6 @@ import ( "github.com/hashicorp/terraform-plugin-testing/plancheck" "github.com/hashicorp/terraform-plugin-testing/terraform" "github.com/hashicorp/terraform-plugin-testing/tfversion" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -1026,15 +1024,10 @@ func TestAcc_GrantPrivilegesToAccountRole_ImportedPrivileges(t *testing.T) { TerraformVersionChecks: []tfversion.TerraformVersionCheck{ tfversion.RequireAbove(tfversion.Version1_5_0), }, - CheckDestroy: func(state *terraform.State) error { - return errors.Join( - acc.CheckAccountRolePrivilegesRevoked(t)(state), - dropSharedDatabaseOnSecondaryAccount(t, sharedDatabaseName, shareName), - ) - }, + CheckDestroy: acc.CheckAccountRolePrivilegesRevoked(t), Steps: []resource.TestStep{ { - PreConfig: func() { assert.NoError(t, createSharedDatabaseOnSecondaryAccount(t, sharedDatabaseName, shareName)) }, + PreConfig: func() { createSharedDatabaseOnSecondaryAccount(t, sharedDatabaseName, shareName) }, ConfigDirectory: acc.ConfigurationDirectory("TestAcc_GrantPrivilegesToAccountRole/ImportedPrivileges"), ConfigVariables: configVariables, ConfigPlanChecks: resource.ConfigPlanChecks{ @@ -1549,29 +1542,20 @@ func TestAcc_GrantPrivilegesToAccountRole_AlwaysApply_SetAfterCreate(t *testing. }) } -func createSharedDatabaseOnSecondaryAccount(t *testing.T, databaseName string, shareName string) error { +func createSharedDatabaseOnSecondaryAccount(t *testing.T, databaseName string, shareName string) { t.Helper() - secondaryClient := acc.SecondaryClient(t) - ctx := context.Background() - accountName := acc.TestClient().Context.CurrentAccount(t) - return errors.Join( - secondaryClient.Databases.Create(ctx, sdk.NewAccountObjectIdentifier(databaseName), &sdk.CreateDatabaseOptions{}), - secondaryClient.Shares.Create(ctx, sdk.NewAccountObjectIdentifier(shareName), &sdk.CreateShareOptions{}), - secondaryClient.Grants.GrantPrivilegeToShare(ctx, []sdk.ObjectPrivilege{sdk.ObjectPrivilegeReferenceUsage}, &sdk.ShareGrantOn{Database: sdk.NewAccountObjectIdentifier(databaseName)}, sdk.NewAccountObjectIdentifier(shareName)), - secondaryClient.Shares.Alter(ctx, sdk.NewAccountObjectIdentifier(shareName), &sdk.AlterShareOptions{Set: &sdk.ShareSet{ - Accounts: []sdk.AccountIdentifier{sdk.NewAccountIdentifierFromAccountLocator(accountName)}, - }}), - ) -} -func dropSharedDatabaseOnSecondaryAccount(t *testing.T, databaseName string, shareName string) error { - t.Helper() - secondaryClient := acc.SecondaryClient(t) - ctx := context.Background() - return errors.Join( - secondaryClient.Shares.Drop(ctx, sdk.NewAccountObjectIdentifier(shareName)), - secondaryClient.Databases.Drop(ctx, sdk.NewAccountObjectIdentifier(databaseName), &sdk.DropDatabaseOptions{}), - ) + database, databaseCleanup := acc.SecondaryTestClient().Database.CreateDatabaseWithName(t, databaseName) + t.Cleanup(databaseCleanup) + + share, shareCleanup := acc.SecondaryTestClient().Share.CreateShareWithName(t, shareName) + t.Cleanup(shareCleanup) + + acc.SecondaryTestClient().Role.GrantPrivilegeOnDatabaseToShare(t, database.ID(), share.ID()) + + accountName := acc.TestClient().Context.CurrentAccount(t) + accountId := sdk.NewAccountIdentifierFromAccountLocator(accountName) + acc.SecondaryTestClient().Share.SetAccountOnShare(t, accountId, share.ID()) } func queriedAccountRolePrivilegesEqualTo(roleName sdk.AccountObjectIdentifier, privileges ...string) func(s *terraform.State) error { diff --git a/pkg/resources/grant_privileges_to_role_acceptance_test.go b/pkg/resources/grant_privileges_to_role_acceptance_test.go index ce3fca0f46..9103fc7dd0 100644 --- a/pkg/resources/grant_privileges_to_role_acceptance_test.go +++ b/pkg/resources/grant_privileges_to_role_acceptance_test.go @@ -1,7 +1,6 @@ package resources_test import ( - "errors" "fmt" "regexp" "strings" @@ -14,9 +13,7 @@ import ( "github.com/hashicorp/terraform-plugin-testing/helper/acctest" "github.com/hashicorp/terraform-plugin-testing/helper/resource" "github.com/hashicorp/terraform-plugin-testing/plancheck" - "github.com/hashicorp/terraform-plugin-testing/terraform" "github.com/hashicorp/terraform-plugin-testing/tfversion" - "github.com/stretchr/testify/assert" ) func TestAcc_GrantPrivilegesToRole_onAccount(t *testing.T) { @@ -1103,15 +1100,10 @@ func TestAcc_GrantPrivilegesToRole_ImportedPrivileges(t *testing.T) { TerraformVersionChecks: []tfversion.TerraformVersionCheck{ tfversion.RequireAbove(tfversion.Version1_5_0), }, - CheckDestroy: func(state *terraform.State) error { - return errors.Join( - acc.CheckAccountRolePrivilegesRevoked(t)(state), - dropSharedDatabaseOnSecondaryAccount(t, sharedDatabaseName, shareName), - ) - }, + CheckDestroy: acc.CheckAccountRolePrivilegesRevoked(t), Steps: []resource.TestStep{ { - PreConfig: func() { assert.NoError(t, createSharedDatabaseOnSecondaryAccount(t, sharedDatabaseName, shareName)) }, + PreConfig: func() { createSharedDatabaseOnSecondaryAccount(t, sharedDatabaseName, shareName) }, ConfigDirectory: acc.ConfigurationDirectory("TestAcc_GrantPrivilegesToRole/ImportedPrivileges"), ConfigVariables: configVariables, ConfigPlanChecks: resource.ConfigPlanChecks{ diff --git a/pkg/resources/grant_privileges_to_share_acceptance_test.go b/pkg/resources/grant_privileges_to_share_acceptance_test.go index 0d248e940b..9efe9f9fbe 100644 --- a/pkg/resources/grant_privileges_to_share_acceptance_test.go +++ b/pkg/resources/grant_privileges_to_share_acceptance_test.go @@ -544,7 +544,8 @@ func TestAcc_GrantPrivilegesToShare_RemoveShareOutsideTerraform(t *testing.T) { Steps: []resource.TestStep{ { PreConfig: func() { - shareCleanup = createShareOutsideTerraform(t, shareName) + _, shareCleanup = acc.TestClient().Share.CreateShareWithName(t, shareName) + t.Cleanup(shareCleanup) }, ConfigDirectory: acc.ConfigurationDirectory("TestAcc_GrantPrivilegesToShare/OnCustomShare"), ConfigVariables: configVariables, @@ -591,23 +592,3 @@ func testAccCheckSharePrivilegesRevoked() func(*terraform.State) error { return nil } } - -func createShareOutsideTerraform(t *testing.T, name string) func() { - t.Helper() - client := acc.Client(t) - ctx := context.Background() - - if err := client.Shares.Create(ctx, sdk.NewAccountObjectIdentifier(name), new(sdk.CreateShareOptions)); err != nil { - if err != nil { - t.Fatal(err) - } - } - - return func() { - if err := client.Shares.Drop(ctx, sdk.NewAccountObjectIdentifier(name)); err != nil { - if err != nil { - t.Fatal(err) - } - } - } -} diff --git a/pkg/resources/helpers_test.go b/pkg/resources/helpers_test.go index 9ddee8930b..d3a1c611e6 100644 --- a/pkg/resources/helpers_test.go +++ b/pkg/resources/helpers_test.go @@ -14,7 +14,6 @@ import ( "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/hashicorp/terraform-plugin-testing/terraform" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) type grantType int @@ -99,27 +98,3 @@ func queriedPrivilegesContainAtLeast(query func(client *sdk.Client, ctx context. return nil } } - -// TODO(SNOW-936093): This function should be merged with testint/helpers_test.go updateAccountParameterTemporarily function which does the same thing. -// We cannot use it right now because it requires moving the function between the packages, so both tests will be able to see it. -func updateAccountParameter(t *testing.T, client *sdk.Client, parameter sdk.AccountParameter, temporarily bool, newValue string) func() { - t.Helper() - - ctx := context.Background() - - param, err := client.Parameters.ShowAccountParameter(ctx, parameter) - require.NoError(t, err) - oldValue := param.Value - - if temporarily { - t.Cleanup(func() { - err = client.Parameters.SetAccountParameter(ctx, parameter, oldValue) - require.NoError(t, err) - }) - } - - return func() { - err = client.Parameters.SetAccountParameter(ctx, parameter, newValue) - require.NoError(t, err) - } -} diff --git a/pkg/resources/share.go b/pkg/resources/share.go index d4ac2bd164..d6c3d5efe1 100644 --- a/pkg/resources/share.go +++ b/pkg/resources/share.go @@ -245,7 +245,7 @@ func UpdateShare(d *schema.ResourceData, meta interface{}) error { func DeleteShare(d *schema.ResourceData, meta interface{}) error { client := meta.(*provider.Context).Client ctx := context.Background() - err := client.Shares.Drop(ctx, sdk.NewAccountObjectIdentifier(d.Id())) + err := client.Shares.Drop(ctx, sdk.NewAccountObjectIdentifier(d.Id()), &sdk.DropShareOptions{IfExists: sdk.Bool(true)}) if err != nil { return fmt.Errorf("error deleting share (%v) err = %w", d.Id(), err) } diff --git a/pkg/sdk/application_packages_gen_test.go b/pkg/sdk/application_packages_gen_test.go index 339846bc84..4ca815f169 100644 --- a/pkg/sdk/application_packages_gen_test.go +++ b/pkg/sdk/application_packages_gen_test.go @@ -231,7 +231,8 @@ func TestApplicationPackages_Drop(t *testing.T) { t.Run("all options", func(t *testing.T) { opts := defaultOpts() - assertOptsValidAndSQLEquals(t, opts, `DROP APPLICATION PACKAGE %s`, id.FullyQualifiedName()) + opts.IfExists = Bool(true) + assertOptsValidAndSQLEquals(t, opts, `DROP APPLICATION PACKAGE IF EXISTS %s`, id.FullyQualifiedName()) }) } diff --git a/pkg/sdk/masking_policy.go b/pkg/sdk/masking_policy.go index 724af71f4d..4c50c36ce9 100644 --- a/pkg/sdk/masking_policy.go +++ b/pkg/sdk/masking_policy.go @@ -188,9 +188,7 @@ func (opts *DropMaskingPolicyOptions) validate() error { } func (v *maskingPolicies) Drop(ctx context.Context, id SchemaObjectIdentifier, opts *DropMaskingPolicyOptions) error { - if opts == nil { - return errors.Join(ErrNilOptions) - } + opts = createIfNil(opts) opts.name = id if err := opts.validate(); err != nil { return fmt.Errorf("validate drop options: %w", err) diff --git a/pkg/sdk/resource_monitors.go b/pkg/sdk/resource_monitors.go index 8dc9f74ce8..817c4df045 100644 --- a/pkg/sdk/resource_monitors.go +++ b/pkg/sdk/resource_monitors.go @@ -358,9 +358,7 @@ func (opts *DropResourceMonitorOptions) validate() error { } func (v *resourceMonitors) Drop(ctx context.Context, id AccountObjectIdentifier, opts *DropResourceMonitorOptions) error { - if opts == nil { - return errors.Join(ErrNilOptions) - } + opts = createIfNil(opts) opts.name = id if err := opts.validate(); err != nil { return err diff --git a/pkg/sdk/shares.go b/pkg/sdk/shares.go index 9f34bf2f13..aa92fb0790 100644 --- a/pkg/sdk/shares.go +++ b/pkg/sdk/shares.go @@ -11,7 +11,7 @@ import ( var ( _ validatable = new(CreateShareOptions) _ validatable = new(AlterShareOptions) - _ validatable = new(dropShareOptions) + _ validatable = new(DropShareOptions) _ validatable = new(ShowShareOptions) _ validatable = new(describeShareOptions) ) @@ -19,7 +19,7 @@ var ( type Shares interface { Create(ctx context.Context, id AccountObjectIdentifier, opts *CreateShareOptions) error Alter(ctx context.Context, id AccountObjectIdentifier, opts *AlterShareOptions) error - Drop(ctx context.Context, id AccountObjectIdentifier) error + Drop(ctx context.Context, id AccountObjectIdentifier, opts *DropShareOptions) error Show(ctx context.Context, opts *ShowShareOptions) ([]Share, error) ShowByID(ctx context.Context, id AccountObjectIdentifier) (*Share, error) DescribeProvider(ctx context.Context, id AccountObjectIdentifier) (*ShareDetails, error) @@ -121,7 +121,7 @@ func (opts *CreateShareOptions) validate() error { return nil } -func (v *shares) Create(ctx context.Context, id AccountObjectIdentifier, opts *CreateShareOptions) error { +func (s *shares) Create(ctx context.Context, id AccountObjectIdentifier, opts *CreateShareOptions) error { if opts == nil { opts = &CreateShareOptions{} } @@ -133,18 +133,19 @@ func (v *shares) Create(ctx context.Context, id AccountObjectIdentifier, opts *C if err != nil { return err } - _, err = v.client.exec(ctx, sql) + _, err = s.client.exec(ctx, sql) return err } -// dropShareOptions is based on https://docs.snowflake.com/en/sql-reference/sql/drop-share. -type dropShareOptions struct { - drop bool `ddl:"static" sql:"DROP"` - share bool `ddl:"static" sql:"SHARE"` - name AccountObjectIdentifier `ddl:"identifier"` +// DropShareOptions is based on https://docs.snowflake.com/en/sql-reference/sql/drop-share. +type DropShareOptions struct { + drop bool `ddl:"static" sql:"DROP"` + share bool `ddl:"static" sql:"SHARE"` + IfExists *bool `ddl:"keyword" sql:"IF EXISTS"` + name AccountObjectIdentifier `ddl:"identifier"` } -func (opts *dropShareOptions) validate() error { +func (opts *DropShareOptions) validate() error { if opts == nil { return errors.Join(ErrNilOptions) } @@ -154,10 +155,9 @@ func (opts *dropShareOptions) validate() error { return nil } -func (v *shares) Drop(ctx context.Context, id AccountObjectIdentifier) error { - opts := &dropShareOptions{ - name: id, - } +func (s *shares) Drop(ctx context.Context, id AccountObjectIdentifier, opts *DropShareOptions) error { + opts = createIfNil(opts) + opts.name = id if err := opts.validate(); err != nil { return err } @@ -165,7 +165,7 @@ func (v *shares) Drop(ctx context.Context, id AccountObjectIdentifier) error { if err != nil { return err } - _, err = v.client.exec(ctx, sql) + _, err = s.client.exec(ctx, sql) return err } @@ -263,7 +263,7 @@ func (v *ShareUnset) validate() error { return nil } -func (v *shares) Alter(ctx context.Context, id AccountObjectIdentifier, opts *AlterShareOptions) error { +func (s *shares) Alter(ctx context.Context, id AccountObjectIdentifier, opts *AlterShareOptions) error { if opts == nil { opts = &AlterShareOptions{} } @@ -275,7 +275,7 @@ func (v *shares) Alter(ctx context.Context, id AccountObjectIdentifier, opts *Al if err != nil { return err } - _, err = v.client.exec(ctx, sql) + _, err = s.client.exec(ctx, sql) return err } @@ -374,7 +374,7 @@ func (opts *describeShareOptions) validate() error { return nil } -func (c *shares) DescribeProvider(ctx context.Context, id AccountObjectIdentifier) (*ShareDetails, error) { +func (s *shares) DescribeProvider(ctx context.Context, id AccountObjectIdentifier) (*ShareDetails, error) { opts := &describeShareOptions{ name: id, } @@ -383,14 +383,14 @@ func (c *shares) DescribeProvider(ctx context.Context, id AccountObjectIdentifie return nil, err } var rows []shareDetailsRow - err = c.client.query(ctx, &rows, sql) + err = s.client.query(ctx, &rows, sql) if err != nil { return nil, err } return shareDetailsFromRows(rows), nil } -func (c *shares) DescribeConsumer(ctx context.Context, id ExternalObjectIdentifier) (*ShareDetails, error) { +func (s *shares) DescribeConsumer(ctx context.Context, id ExternalObjectIdentifier) (*ShareDetails, error) { opts := &describeShareOptions{ name: id, } @@ -399,7 +399,7 @@ func (c *shares) DescribeConsumer(ctx context.Context, id ExternalObjectIdentifi return nil, err } var rows []shareDetailsRow - err = c.client.query(ctx, &rows, sql) + err = s.client.query(ctx, &rows, sql) if err != nil { return nil, err } diff --git a/pkg/sdk/shares_test.go b/pkg/sdk/shares_test.go index f394d36764..d63f7b1030 100644 --- a/pkg/sdk/shares_test.go +++ b/pkg/sdk/shares_test.go @@ -127,7 +127,7 @@ func TestShareShow(t *testing.T) { func TestShareDrop(t *testing.T) { t.Run("only name", func(t *testing.T) { - opts := &dropShareOptions{ + opts := &DropShareOptions{ name: NewAccountObjectIdentifier("myshare"), } assertOptsValidAndSQLEquals(t, opts, `DROP SHARE "myshare"`) diff --git a/pkg/sdk/sweepers.go b/pkg/sdk/sweepers.go index b974c2f55a..24a05da5de 100644 --- a/pkg/sdk/sweepers.go +++ b/pkg/sdk/sweepers.go @@ -131,7 +131,7 @@ func getShareSweeper(client *Client, suffix string) func() error { for _, share := range shares { if share.Kind == ShareKindOutbound && strings.HasPrefix(share.Name.Name(), suffix) { log.Printf("[DEBUG] Dropping share %s", share.ID().FullyQualifiedName()) - if err := client.Shares.Drop(ctx, share.ID()); err != nil { + if err := client.Shares.Drop(ctx, share.ID(), &DropShareOptions{IfExists: Bool(true)}); err != nil { return fmt.Errorf("sweeping share %s ended with error, err = %w", share.ID().FullyQualifiedName(), err) } } else { diff --git a/pkg/sdk/testint/application_packages_integration_test.go b/pkg/sdk/testint/application_packages_integration_test.go index 7e0e4bce12..138b7bb76e 100644 --- a/pkg/sdk/testint/application_packages_integration_test.go +++ b/pkg/sdk/testint/application_packages_integration_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "os" "strconv" "testing" @@ -178,22 +177,10 @@ func TestInt_ApplicationPackages(t *testing.T) { }) } -type StagedFile struct { - Name string `json:"name"` - Size int `json:"size"` -} - -type ApplicationPackageVersion struct { - Version string `json:"version"` - Patch int `json:"patch"` -} - func TestInt_ApplicationPackagesVersionAndReleaseDirective(t *testing.T) { client := testClient(t) ctx := context.Background() - databaseTest, schemaTest := testDb(t), testSchema(t) - cleanupApplicationPackageHandle := func(id sdk.AccountObjectIdentifier) func() { return func() { err := client.ApplicationPackages.Drop(ctx, sdk.NewDropApplicationPackageRequest(id)) @@ -207,7 +194,7 @@ func TestInt_ApplicationPackagesVersionAndReleaseDirective(t *testing.T) { createApplicationPackageHandle := func(t *testing.T) *sdk.ApplicationPackage { t.Helper() - id := sdk.NewAccountObjectIdentifier("snowflake_package_test") + id := sdk.RandomAccountObjectIdentifier() request := sdk.NewCreateApplicationPackageRequest(id).WithDistribution(sdk.DistributionPointer(sdk.DistributionInternal)) err := client.ApplicationPackages.Create(ctx, request) require.NoError(t, err) @@ -222,65 +209,22 @@ func TestInt_ApplicationPackagesVersionAndReleaseDirective(t *testing.T) { return e } - createStageHandle := func(t *testing.T) *sdk.Stage { - t.Helper() - - id := sdk.NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, "stage_test") - co := sdk.NewStageCopyOptionsRequest().WithOnError(sdk.NewStageCopyOnErrorOptionsRequest().WithSkipFile()) - cr := sdk.NewCreateInternalStageRequest(id).WithCopyOptions(co) - err := client.Stages.CreateInternal(ctx, cr) - require.NoError(t, err) - t.Cleanup(func() { - err = client.Stages.Drop(ctx, sdk.NewDropStageRequest(id)) - require.NoError(t, err) - }) - - e, err := client.Stages.ShowByID(ctx, id) - require.NoError(t, err) - return e - } - - putOnStageHandle := func(t *testing.T, id sdk.SchemaObjectIdentifier, name string) { - t.Helper() - - tempFile := fmt.Sprintf("/tmp/%s", name) - f, err := os.Create(tempFile) - require.NoError(t, err) - f.Close() - defer os.Remove(name) - - _, err = client.ExecForTests(ctx, fmt.Sprintf(`PUT file://%s @%s AUTO_COMPRESS = FALSE OVERWRITE = TRUE`, f.Name(), id.FullyQualifiedName())) - require.NoError(t, err) - t.Cleanup(func() { - _, err = client.ExecForTests(ctx, fmt.Sprintf(`REMOVE @%s/%s`, id.FullyQualifiedName(), name)) - require.NoError(t, err) - }) - } - - showApplicationPackageVersion := func(t *testing.T, name string) []ApplicationPackageVersion { - t.Helper() - - var versions []ApplicationPackageVersion - err := client.QueryForTests(ctx, &versions, fmt.Sprintf(`SHOW VERSIONS IN APPLICATION PACKAGE "%s"`, name)) - require.NoError(t, err) - return versions - } - t.Run("alter application package: add, patch and drop version", func(t *testing.T) { e := createApplicationPackageHandle(t) - s := createStageHandle(t) - putOnStageHandle(t, s.ID(), "manifest.yml") - putOnStageHandle(t, s.ID(), "setup.sql") + stage, stageCleanup := testClientHelper().Stage.CreateStage(t) + t.Cleanup(stageCleanup) + testClientHelper().Stage.PutOnStageWithContent(t, stage.ID(), "manifest.yml", "") + testClientHelper().Stage.PutOnStageWithContent(t, stage.ID(), "setup.sql", "") version := "V001" - using := "@" + s.ID().FullyQualifiedName() + using := "@" + stage.ID().FullyQualifiedName() // add version to application package id := sdk.NewAccountObjectIdentifier(e.Name) vr := sdk.NewAddVersionRequest(using).WithVersionIdentifier(&version).WithLabel(sdk.String("add version V001")) r1 := sdk.NewAlterApplicationPackageRequest(id).WithAddVersion(vr) err := client.ApplicationPackages.Alter(ctx, r1) require.NoError(t, err) - versions := showApplicationPackageVersion(t, e.Name) + versions := testClientHelper().ApplicationPackage.ShowVersions(t, e.ID()) require.Equal(t, 1, len(versions)) require.Equal(t, version, versions[0].Version) require.Equal(t, 0, versions[0].Patch) @@ -290,7 +234,7 @@ func TestInt_ApplicationPackagesVersionAndReleaseDirective(t *testing.T) { r2 := sdk.NewAlterApplicationPackageRequest(id).WithAddPatchForVersion(pr) err = client.ApplicationPackages.Alter(ctx, r2) require.NoError(t, err) - versions = showApplicationPackageVersion(t, e.Name) + versions = testClientHelper().ApplicationPackage.ShowVersions(t, e.ID()) require.Equal(t, 2, len(versions)) require.Equal(t, version, versions[0].Version) require.Equal(t, 0, versions[0].Patch) @@ -301,25 +245,26 @@ func TestInt_ApplicationPackagesVersionAndReleaseDirective(t *testing.T) { r3 := sdk.NewAlterApplicationPackageRequest(id).WithDropVersion(sdk.NewDropVersionRequest(version)) err = client.ApplicationPackages.Alter(ctx, r3) require.NoError(t, err) - versions = showApplicationPackageVersion(t, e.Name) + versions = testClientHelper().ApplicationPackage.ShowVersions(t, e.ID()) require.Equal(t, 0, len(versions)) }) t.Run("alter application package: set default release directive", func(t *testing.T) { e := createApplicationPackageHandle(t) - s := createStageHandle(t) - putOnStageHandle(t, s.ID(), "manifest.yml") - putOnStageHandle(t, s.ID(), "setup.sql") + stage, stageCleanup := testClientHelper().Stage.CreateStage(t) + t.Cleanup(stageCleanup) + testClientHelper().Stage.PutOnStageWithContent(t, stage.ID(), "manifest.yml", "") + testClientHelper().Stage.PutOnStageWithContent(t, stage.ID(), "setup.sql", "") version := "V001" - using := "@" + s.ID().FullyQualifiedName() + using := "@" + stage.ID().FullyQualifiedName() // add version to application package id := sdk.NewAccountObjectIdentifier(e.Name) vr := sdk.NewAddVersionRequest(using).WithVersionIdentifier(&version).WithLabel(sdk.String("add version V001")) r1 := sdk.NewAlterApplicationPackageRequest(id).WithAddVersion(vr) err := client.ApplicationPackages.Alter(ctx, r1) require.NoError(t, err) - versions := showApplicationPackageVersion(t, e.Name) + versions := testClientHelper().ApplicationPackage.ShowVersions(t, e.ID()) require.Equal(t, 1, len(versions)) require.Equal(t, version, versions[0].Version) require.Equal(t, 0, versions[0].Patch) diff --git a/pkg/sdk/testint/database_role_integration_test.go b/pkg/sdk/testint/database_role_integration_test.go index 17c969ce42..0f34fda4a3 100644 --- a/pkg/sdk/testint/database_role_integration_test.go +++ b/pkg/sdk/testint/database_role_integration_test.go @@ -259,7 +259,7 @@ func TestInt_DatabaseRoles(t *testing.T) { role := createDatabaseRole(t) roleId := sdk.NewDatabaseObjectIdentifier(testDb(t).Name, role.Name) - share, shareCleanup := createShare(t, client) + share, shareCleanup := testClientHelper().Share.CreateShare(t) t.Cleanup(shareCleanup) err := client.Grants.GrantPrivilegeToShare(ctx, []sdk.ObjectPrivilege{sdk.ObjectPrivilegeUsage}, &sdk.ShareGrantOn{Database: testDb(t).ID()}, share.ID()) diff --git a/pkg/sdk/testint/databases_integration_test.go b/pkg/sdk/testint/databases_integration_test.go index d8d39a8011..82a832de1d 100644 --- a/pkg/sdk/testint/databases_integration_test.go +++ b/pkg/sdk/testint/databases_integration_test.go @@ -116,7 +116,7 @@ func TestInt_CreateShared(t *testing.T) { databaseTest, databaseCleanup := secondaryTestClientHelper().Database.CreateDatabase(t) t.Cleanup(databaseCleanup) - shareTest, shareCleanup := createShare(t, secondaryClient) + shareTest, shareCleanup := secondaryTestClientHelper().Share.CreateShare(t) t.Cleanup(shareCleanup) err := secondaryClient.Grants.GrantPrivilegeToShare(ctx, []sdk.ObjectPrivilege{sdk.ObjectPrivilegeUsage}, &sdk.ShareGrantOn{ @@ -131,7 +131,7 @@ func TestInt_CreateShared(t *testing.T) { }) accountsToSet := []sdk.AccountIdentifier{ - getAccountIdentifier(t, client), + testClientHelper().Account.GetAccountIdentifier(t), } // first add the account. @@ -275,7 +275,6 @@ func TestInt_AlterReplication(t *testing.T) { } func TestInt_AlterFailover(t *testing.T) { - client := testClient(t) secondaryClient := testSecondaryClient(t) ctx := testContext(t) @@ -283,7 +282,7 @@ func TestInt_AlterFailover(t *testing.T) { t.Cleanup(databaseCleanup) toAccounts := []sdk.AccountIdentifier{ - getAccountIdentifier(t, client), + testClientHelper().Account.GetAccountIdentifier(t), } t.Run("enable and disable failover", func(t *testing.T) { diff --git a/pkg/sdk/testint/event_tables_integration_test.go b/pkg/sdk/testint/event_tables_integration_test.go index 74e67f1265..cc5eef4cae 100644 --- a/pkg/sdk/testint/event_tables_integration_test.go +++ b/pkg/sdk/testint/event_tables_integration_test.go @@ -215,9 +215,9 @@ func TestInt_EventTables(t *testing.T) { // alter view: add and drop row access policies t.Run("alter event table: add and drop row access policies", func(t *testing.T) { - rowAccessPolicyId, rowAccessPolicyCleanup := createRowAccessPolicy(t, client, schemaTest) + rowAccessPolicy, rowAccessPolicyCleanup := testClientHelper().RowAccessPolicy.CreateRowAccessPolicy(t) t.Cleanup(rowAccessPolicyCleanup) - rowAccessPolicy2Id, rowAccessPolicy2Cleanup := createRowAccessPolicy(t, client, schemaTest) + rowAccessPolicy2, rowAccessPolicy2Cleanup := testClientHelper().RowAccessPolicy.CreateRowAccessPolicy(t) t.Cleanup(rowAccessPolicy2Cleanup) table, tableCleanup := testClientHelper().Table.CreateTable(t) @@ -225,53 +225,53 @@ func TestInt_EventTables(t *testing.T) { id := sdk.NewSchemaObjectIdentifier(table.DatabaseName, table.SchemaName, table.Name) // add policy - alterRequest := sdk.NewAlterEventTableRequest(id).WithAddRowAccessPolicy(sdk.NewEventTableAddRowAccessPolicyRequest(rowAccessPolicyId, []string{"id"})) + alterRequest := sdk.NewAlterEventTableRequest(id).WithAddRowAccessPolicy(sdk.NewEventTableAddRowAccessPolicyRequest(rowAccessPolicy.ID(), []string{"id"})) err := client.EventTables.Alter(ctx, alterRequest) require.NoError(t, err) - e, err := getRowAccessPolicyFor(t, client, table.ID(), sdk.ObjectTypeTable) + e, err := testClientHelper().RowAccessPolicy.GetRowAccessPolicyFor(t, table.ID(), sdk.ObjectTypeTable) require.NoError(t, err) - assert.Equal(t, rowAccessPolicyId.Name(), e.PolicyName) + assert.Equal(t, rowAccessPolicy.ID().Name(), e.PolicyName) assert.Equal(t, "ROW_ACCESS_POLICY", e.PolicyKind) assert.Equal(t, table.ID().Name(), e.RefEntityName) assert.Equal(t, "TABLE", e.RefEntityDomain) assert.Equal(t, "ACTIVE", e.PolicyStatus) // remove policy - alterRequest = sdk.NewAlterEventTableRequest(id).WithDropRowAccessPolicy(sdk.NewEventTableDropRowAccessPolicyRequest(rowAccessPolicyId)) + alterRequest = sdk.NewAlterEventTableRequest(id).WithDropRowAccessPolicy(sdk.NewEventTableDropRowAccessPolicyRequest(rowAccessPolicy.ID())) err = client.EventTables.Alter(ctx, alterRequest) require.NoError(t, err) - _, err = getRowAccessPolicyFor(t, client, table.ID(), sdk.ObjectTypeTable) + _, err = testClientHelper().RowAccessPolicy.GetRowAccessPolicyFor(t, table.ID(), sdk.ObjectTypeTable) require.Error(t, err, "no rows in result set") // add policy again - alterRequest = sdk.NewAlterEventTableRequest(id).WithAddRowAccessPolicy(sdk.NewEventTableAddRowAccessPolicyRequest(rowAccessPolicyId, []string{"id"})) + alterRequest = sdk.NewAlterEventTableRequest(id).WithAddRowAccessPolicy(sdk.NewEventTableAddRowAccessPolicyRequest(rowAccessPolicy.ID(), []string{"id"})) err = client.EventTables.Alter(ctx, alterRequest) require.NoError(t, err) - e, err = getRowAccessPolicyFor(t, client, table.ID(), sdk.ObjectTypeTable) + e, err = testClientHelper().RowAccessPolicy.GetRowAccessPolicyFor(t, table.ID(), sdk.ObjectTypeTable) require.NoError(t, err) - assert.Equal(t, rowAccessPolicyId.Name(), e.PolicyName) + assert.Equal(t, rowAccessPolicy.ID().Name(), e.PolicyName) // drop and add other policy simultaneously alterRequest = sdk.NewAlterEventTableRequest(id).WithDropAndAddRowAccessPolicy(sdk.NewEventTableDropAndAddRowAccessPolicyRequest( - *sdk.NewEventTableDropRowAccessPolicyRequest(rowAccessPolicyId), - *sdk.NewEventTableAddRowAccessPolicyRequest(rowAccessPolicy2Id, []string{"id"}), + *sdk.NewEventTableDropRowAccessPolicyRequest(rowAccessPolicy.ID()), + *sdk.NewEventTableAddRowAccessPolicyRequest(rowAccessPolicy2.ID(), []string{"id"}), )) err = client.EventTables.Alter(ctx, alterRequest) require.NoError(t, err) - e, err = getRowAccessPolicyFor(t, client, table.ID(), sdk.ObjectTypeTable) + e, err = testClientHelper().RowAccessPolicy.GetRowAccessPolicyFor(t, table.ID(), sdk.ObjectTypeTable) require.NoError(t, err) - assert.Equal(t, rowAccessPolicy2Id.Name(), e.PolicyName) + assert.Equal(t, rowAccessPolicy2.ID().Name(), e.PolicyName) // drop all policies alterRequest = sdk.NewAlterEventTableRequest(id).WithDropAllRowAccessPolicies(sdk.Bool(true)) err = client.EventTables.Alter(ctx, alterRequest) require.NoError(t, err) - _, err = getRowAccessPolicyFor(t, client, table.ID(), sdk.ObjectTypeView) + _, err = testClientHelper().RowAccessPolicy.GetRowAccessPolicyFor(t, table.ID(), sdk.ObjectTypeView) require.Error(t, err, "no rows in result set") }) } diff --git a/pkg/sdk/testint/external_functions_integration_test.go b/pkg/sdk/testint/external_functions_integration_test.go index a6ff57a3d7..5eaf233785 100644 --- a/pkg/sdk/testint/external_functions_integration_test.go +++ b/pkg/sdk/testint/external_functions_integration_test.go @@ -18,7 +18,7 @@ func TestInt_ExternalFunctions(t *testing.T) { databaseTest, schemaTest := testDb(t), testSchema(t) - integration, integrationCleanup := createApiIntegration(t, client) + integration, integrationCleanup := testClientHelper().ApiIntegration.CreateApiIntegration(t) t.Cleanup(integrationCleanup) cleanupExternalFunctionHandle := func(id sdk.SchemaObjectIdentifier, dts []sdk.DataType) func() { @@ -33,7 +33,7 @@ func TestInt_ExternalFunctions(t *testing.T) { id := sdk.NewSchemaObjectIdentifierWithArguments(databaseTest.Name, schemaTest.Name, random.StringN(4), defaultDataTypes) argument := sdk.NewExternalFunctionArgumentRequest("x", defaultDataTypes[0]) as := "https://xyz.execute-api.us-west-2.amazonaws.com/production/remote_echo" - request := sdk.NewCreateExternalFunctionRequest(id, sdk.DataTypeVariant, &integration, as). + request := sdk.NewCreateExternalFunctionRequest(id, sdk.DataTypeVariant, sdk.Pointer(integration.ID()), as). WithOrReplace(sdk.Bool(true)). WithSecure(sdk.Bool(true)). WithArguments([]sdk.ExternalFunctionArgumentRequest{*argument}) @@ -96,7 +96,7 @@ func TestInt_ExternalFunctions(t *testing.T) { }, } as := "https://xyz.execute-api.us-west-2.amazonaws.com/production/remote_echo" - request := sdk.NewCreateExternalFunctionRequest(id, sdk.DataTypeVariant, &integration, as). + request := sdk.NewCreateExternalFunctionRequest(id, sdk.DataTypeVariant, sdk.Pointer(integration.ID()), as). WithOrReplace(sdk.Bool(true)). WithSecure(sdk.Bool(true)). WithArguments([]sdk.ExternalFunctionArgumentRequest{*argument}). @@ -115,7 +115,7 @@ func TestInt_ExternalFunctions(t *testing.T) { t.Run("create external function without arguments", func(t *testing.T) { id := sdk.NewSchemaObjectIdentifierWithArguments(databaseTest.Name, schemaTest.Name, random.StringN(4), nil) as := "https://xyz.execute-api.us-west-2.amazonaws.com/production/remote_echo" - request := sdk.NewCreateExternalFunctionRequest(id, sdk.DataTypeVariant, &integration, as) + request := sdk.NewCreateExternalFunctionRequest(id, sdk.DataTypeVariant, sdk.Pointer(integration.ID()), as) err := client.ExternalFunctions.Create(ctx, request) require.NoError(t, err) t.Cleanup(cleanupExternalFunctionHandle(id, nil)) @@ -127,7 +127,7 @@ func TestInt_ExternalFunctions(t *testing.T) { e := createExternalFunction(t) id := sdk.NewSchemaObjectIdentifierWithArguments(databaseTest.Name, schemaTest.Name, e.Name, defaultDataTypes) set := sdk.NewExternalFunctionSetRequest(). - WithApiIntegration(&integration) + WithApiIntegration(sdk.Pointer(integration.ID())) request := sdk.NewAlterExternalFunctionRequest(id, defaultDataTypes).WithSet(set) err := client.ExternalFunctions.Alter(ctx, request) require.NoError(t, err) diff --git a/pkg/sdk/testint/external_tables_integration_test.go b/pkg/sdk/testint/external_tables_integration_test.go index da2a1f683b..a137368c35 100644 --- a/pkg/sdk/testint/external_tables_integration_test.go +++ b/pkg/sdk/testint/external_tables_integration_test.go @@ -18,7 +18,7 @@ func TestInt_ExternalTables(t *testing.T) { stageID := sdk.NewSchemaObjectIdentifier(TestDatabaseName, TestSchemaName, "EXTERNAL_TABLE_STAGE") stageLocation := fmt.Sprintf("@%s", stageID.FullyQualifiedName()) - _, stageCleanup := testClientHelper().Stage.CreateStageWithURL(t, stageID, nycWeatherDataURL) + _, stageCleanup := testClientHelper().Stage.CreateStageWithURL(t, stageID) t.Cleanup(stageCleanup) tag, tagCleanup := testClientHelper().Tag.CreateTag(t) @@ -425,7 +425,7 @@ func TestInt_ExternalTablesShowByID(t *testing.T) { databaseTest, schemaTest := testDb(t), testSchema(t) stage := sdk.NewSchemaObjectIdentifier(TestDatabaseName, TestSchemaName, random.AlphaN(6)) - _, stageCleanup := testClientHelper().Stage.CreateStageWithURL(t, stage, nycWeatherDataURL) + _, stageCleanup := testClientHelper().Stage.CreateStageWithURL(t, stage) t.Cleanup(stageCleanup) stageLocation := fmt.Sprintf("@%s", stage.FullyQualifiedName()) diff --git a/pkg/sdk/testint/failover_groups_integration_test.go b/pkg/sdk/testint/failover_groups_integration_test.go index f8a1d89007..6961d1e309 100644 --- a/pkg/sdk/testint/failover_groups_integration_test.go +++ b/pkg/sdk/testint/failover_groups_integration_test.go @@ -20,7 +20,7 @@ func TestInt_FailoverGroupsCreate(t *testing.T) { client := testClient(t) ctx := testContext(t) - shareTest, shareCleanup := createShare(t, client) + shareTest, shareCleanup := testClientHelper().Share.CreateShare(t) t.Cleanup(shareCleanup) accountName := testenvs.GetOrSkipTest(t, testenvs.BusinessCriticalAccount) @@ -81,7 +81,7 @@ func TestInt_FailoverGroupsCreate(t *testing.T) { t.Run("test with identifier containing a dot", func(t *testing.T) { shareId := sdk.NewAccountObjectIdentifier(random.AlphanumericN(6) + "." + random.AlphanumericN(6)) - shareWithDot, shareWithDotCleanup := createShareWithOptions(t, client, shareId, &sdk.CreateShareOptions{}) + shareWithDot, shareWithDotCleanup := testClientHelper().Share.CreateShareWithOptions(t, shareId, &sdk.CreateShareOptions{}) t.Cleanup(shareWithDotCleanup) id := sdk.RandomAccountObjectIdentifier() @@ -224,12 +224,12 @@ func TestInt_CreateSecondaryReplicationGroup(t *testing.T) { client := testClient(t) ctx := testContext(t) - primaryAccountID := getAccountIdentifier(t, client) + primaryAccountID := testClientHelper().Account.GetAccountIdentifier(t) secondaryClient := testSecondaryClient(t) - secondaryClientID := getAccountIdentifier(t, secondaryClient) + secondaryClientID := secondaryTestClientHelper().Account.GetAccountIdentifier(t) // create a temp share - shareTest, cleanupDatabase := createShare(t, client) + shareTest, cleanupDatabase := testClientHelper().Share.CreateShare(t) t.Cleanup(cleanupDatabase) // create a failover group in primary account and share with target account @@ -400,7 +400,7 @@ func TestInt_FailoverGroupsAlterSource(t *testing.T) { }) t.Run("add and remove share account object", func(t *testing.T) { - shareTest, cleanupDatabase := createShare(t, client) + shareTest, cleanupDatabase := testClientHelper().Share.CreateShare(t) t.Cleanup(cleanupDatabase) failoverGroup, cleanupFailoverGroup := testClientHelper().FailoverGroup.CreateFailoverGroup(t) t.Cleanup(cleanupFailoverGroup) @@ -497,7 +497,7 @@ func TestInt_FailoverGroupsAlterSource(t *testing.T) { failoverGroup, cleanupFailoverGroup := testClientHelper().FailoverGroup.CreateFailoverGroup(t) t.Cleanup(cleanupFailoverGroup) - secondaryAccountID := getAccountIdentifier(t, testSecondaryClient(t)) + secondaryAccountID := secondaryTestClientHelper().Account.GetAccountIdentifier(t) // first add target account opts := &sdk.AlterSourceFailoverGroupOptions{ Add: &sdk.FailoverGroupAdd{ @@ -526,7 +526,7 @@ func TestInt_FailoverGroupsAlterSource(t *testing.T) { failoverGroup, err = client.FailoverGroups.ShowByID(ctx, failoverGroup.ID()) require.NoError(t, err) assert.Equal(t, 1, len(failoverGroup.AllowedAccounts)) - assert.Contains(t, failoverGroup.AllowedAccounts, getAccountIdentifier(t, client)) + assert.Contains(t, failoverGroup.AllowedAccounts, testClientHelper().Account.GetAccountIdentifier(t)) }) t.Run("move shares to another failover group", func(t *testing.T) { @@ -551,7 +551,7 @@ func TestInt_FailoverGroupsAlterSource(t *testing.T) { require.NoError(t, err) // create a temp share - shareTest, cleanupShare := createShare(t, client) + shareTest, cleanupShare := testClientHelper().Share.CreateShare(t) t.Cleanup(cleanupShare) // now add share to allowed shares of failover group 1 @@ -654,9 +654,9 @@ func TestInt_FailoverGroupsAlterTarget(t *testing.T) { client := testClient(t) ctx := testContext(t) - primaryAccountID := getAccountIdentifier(t, client) + primaryAccountID := testClientHelper().Account.GetAccountIdentifier(t) secondaryClient := testSecondaryClient(t) - secondaryClientID := getAccountIdentifier(t, secondaryClient) + secondaryClientID := secondaryTestClientHelper().Account.GetAccountIdentifier(t) // create a temp database databaseTest, cleanupDatabase := testClientHelper().Database.CreateDatabase(t) @@ -871,7 +871,7 @@ func TestInt_FailoverGroupsShowShares(t *testing.T) { failoverGroupTest, failoverGroupCleanup := testClientHelper().FailoverGroup.CreateFailoverGroup(t) t.Cleanup(failoverGroupCleanup) - shareTest, shareCleanup := createShare(t, client) + shareTest, shareCleanup := testClientHelper().Share.CreateShare(t) t.Cleanup(shareCleanup) opts := &sdk.AlterSourceFailoverGroupOptions{ Set: &sdk.FailoverGroupSet{ diff --git a/pkg/sdk/testint/grants_integration_test.go b/pkg/sdk/testint/grants_integration_test.go index 8962ce90d9..62c2933a72 100644 --- a/pkg/sdk/testint/grants_integration_test.go +++ b/pkg/sdk/testint/grants_integration_test.go @@ -4,7 +4,6 @@ import ( "fmt" "testing" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/helpers/random" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/internal/collections" "github.com/stretchr/testify/assert" @@ -709,7 +708,7 @@ func TestInt_GrantAndRevokePrivilegesToDatabaseRole(t *testing.T) { func TestInt_GrantPrivilegeToShare(t *testing.T) { client := testClient(t) ctx := testContext(t) - shareTest, shareCleanup := createShare(t, client) + shareTest, shareCleanup := testClientHelper().Share.CreateShare(t) t.Cleanup(shareCleanup) assertGrant := func(t *testing.T, grants []sdk.Grant, onId sdk.ObjectIdentifier, privilege sdk.ObjectPrivilege) { @@ -797,7 +796,7 @@ func TestInt_GrantPrivilegeToShare(t *testing.T) { func TestInt_RevokePrivilegeToShare(t *testing.T) { client := testClient(t) ctx := testContext(t) - shareTest, shareCleanup := createShare(t, client) + shareTest, shareCleanup := testClientHelper().Share.CreateShare(t) t.Cleanup(shareCleanup) err := client.Grants.GrantPrivilegeToShare(ctx, []sdk.ObjectPrivilege{sdk.ObjectPrivilegeUsage}, &sdk.ShareGrantOn{ Database: testDb(t).ID(), @@ -1428,7 +1427,7 @@ func TestInt_GrantOwnership(t *testing.T) { }) t.Run("on task - with ownership", func(t *testing.T) { - task, taskCleanup := createTask(t, client, testDb(t), testSchema(t)) + task, taskCleanup := testClientHelper().Task.CreateTask(t) t.Cleanup(taskCleanup) err := client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(task.ID()).WithResume(sdk.Bool(true))) @@ -1473,9 +1472,7 @@ func TestInt_GrantOwnership(t *testing.T) { // Use a previously prepared role to create a task usePreviousRole := testClientHelper().Role.UseRole(t, taskRole.Name) - taskId := sdk.NewSchemaObjectIdentifier(TestDatabaseName, TestSchemaName, random.AlphaN(20)) - withWarehouseReq := sdk.NewCreateTaskWarehouseRequest().WithWarehouse(sdk.Pointer(testWarehouse(t).ID())) - task, taskCleanup := createTaskWithRequest(t, client, sdk.NewCreateTaskRequest(taskId, "SELECT CURRENT_TIMESTAMP").WithWarehouse(withWarehouseReq).WithSchedule(sdk.String("60 minutes"))) + task, taskCleanup := testClientHelper().Task.CreateTask(t) t.Cleanup(func() { usePreviousRole := testClientHelper().Role.UseRole(t, taskRole.Name) defer usePreviousRole() @@ -1521,9 +1518,7 @@ func TestInt_GrantOwnership(t *testing.T) { // Use a previously prepared role to create a task usePreviousRole := testClientHelper().Role.UseRole(t, taskRole.Name) - taskId := sdk.NewSchemaObjectIdentifier(TestDatabaseName, TestSchemaName, random.AlphaN(20)) - withWarehouseReq := sdk.NewCreateTaskWarehouseRequest().WithWarehouse(sdk.Pointer(testWarehouse(t).ID())) - task, taskCleanup := createTaskWithRequest(t, client, sdk.NewCreateTaskRequest(taskId, "SELECT CURRENT_TIMESTAMP").WithWarehouse(withWarehouseReq).WithSchedule(sdk.String("60 minutes"))) + task, taskCleanup := testClientHelper().Task.CreateTask(t) t.Cleanup(taskCleanup) err := client.Grants.GrantPrivilegesToAccountRole( @@ -1579,13 +1574,13 @@ func TestInt_GrantOwnership(t *testing.T) { }) t.Run("on all tasks - with ownership", func(t *testing.T) { - task, taskCleanup := createTask(t, client, testDb(t), testSchema(t)) + task, taskCleanup := testClientHelper().Task.CreateTask(t) t.Cleanup(taskCleanup) err := client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(task.ID()).WithResume(sdk.Bool(true))) require.NoError(t, err) - secondTask, secondTaskCleanup := createTask(t, client, testDb(t), testSchema(t)) + secondTask, secondTaskCleanup := testClientHelper().Task.CreateTask(t) t.Cleanup(secondTaskCleanup) err = client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(secondTask.ID()).WithResume(sdk.Bool(true))) @@ -1651,13 +1646,10 @@ func TestInt_GrantOwnership(t *testing.T) { // Use a previously prepared role to create a task usePreviousRole := testClientHelper().Role.UseRole(t, taskRole.Name) - taskId := sdk.NewSchemaObjectIdentifier(TestDatabaseName, TestSchemaName, random.AlphaN(20)) - withWarehouseReq := sdk.NewCreateTaskWarehouseRequest().WithWarehouse(sdk.Pointer(testWarehouse(t).ID())) - task, taskCleanup := createTaskWithRequest(t, client, sdk.NewCreateTaskRequest(taskId, "SELECT CURRENT_TIMESTAMP").WithWarehouse(withWarehouseReq).WithSchedule(sdk.String("60 minutes"))) + task, taskCleanup := testClientHelper().Task.CreateTask(t) t.Cleanup(taskCleanup) - secondTaskId := sdk.NewSchemaObjectIdentifier(TestDatabaseName, TestSchemaName, random.AlphaN(20)) - secondTask, secondTaskCleanup := createTaskWithRequest(t, client, sdk.NewCreateTaskRequest(secondTaskId, "SELECT CURRENT_TIMESTAMP").WithWarehouse(withWarehouseReq).WithAfter([]sdk.SchemaObjectIdentifier{task.ID()})) + secondTask, secondTaskCleanup := testClientHelper().Task.CreateTaskWithAfter(t, task.ID()) t.Cleanup(secondTaskCleanup) err := client.Grants.GrantPrivilegesToAccountRole( @@ -1760,7 +1752,7 @@ func TestInt_GrantOwnership(t *testing.T) { func TestInt_ShowGrants(t *testing.T) { client := testClient(t) ctx := testContext(t) - shareTest, shareCleanup := createShare(t, client) + shareTest, shareCleanup := testClientHelper().Share.CreateShare(t) t.Cleanup(shareCleanup) err := client.Grants.GrantPrivilegeToShare(ctx, []sdk.ObjectPrivilege{sdk.ObjectPrivilegeUsage}, &sdk.ShareGrantOn{ Database: testDb(t).ID(), diff --git a/pkg/sdk/testint/helpers_test.go b/pkg/sdk/testint/helpers_test.go deleted file mode 100644 index 7befa3a40e..0000000000 --- a/pkg/sdk/testint/helpers_test.go +++ /dev/null @@ -1,251 +0,0 @@ -package testint - -import ( - "context" - "database/sql" - "fmt" - "testing" - - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/helpers/random" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" - "github.com/stretchr/testify/require" -) - -const ( - nycWeatherDataURL = "s3://snowflake-workshop-lab/weather-nyc" -) - -// there is no direct way to get the account identifier from Snowflake API, but you can get it if you know -// the account locator and by filtering the list of accounts in replication accounts by the account locator -func getAccountIdentifier(t *testing.T, client *sdk.Client) sdk.AccountIdentifier { - t.Helper() - ctx := context.Background() - // TODO: replace later (incoming clients differ) - currentAccountLocator, err := client.ContextFunctions.CurrentAccount(ctx) - require.NoError(t, err) - replicationAccounts, err := client.ReplicationFunctions.ShowReplicationAccounts(ctx) - require.NoError(t, err) - for _, replicationAccount := range replicationAccounts { - if replicationAccount.AccountLocator == currentAccountLocator { - return sdk.NewAccountIdentifier(replicationAccount.OrganizationName, replicationAccount.AccountName) - } - } - return sdk.AccountIdentifier{} -} - -func createShare(t *testing.T, client *sdk.Client) (*sdk.Share, func()) { - t.Helper() - // TODO(SNOW-1058419): Try with identifier containing dot during identifiers rework - id := sdk.RandomAlphanumericAccountObjectIdentifier() - return createShareWithOptions(t, client, id, &sdk.CreateShareOptions{}) -} - -func createShareWithOptions(t *testing.T, client *sdk.Client, id sdk.AccountObjectIdentifier, opts *sdk.CreateShareOptions) (*sdk.Share, func()) { - t.Helper() - ctx := context.Background() - err := client.Shares.Create(ctx, id, opts) - require.NoError(t, err) - share, err := client.Shares.ShowByID(ctx, id) - require.NoError(t, err) - return share, func() { - err := client.Shares.Drop(ctx, id) - require.NoError(t, err) - } -} - -func createView(t *testing.T, client *sdk.Client, viewId sdk.SchemaObjectIdentifier, asQuery string) func() { - t.Helper() - ctx := context.Background() - _, err := client.ExecForTests(ctx, fmt.Sprintf(`CREATE VIEW %s AS %s`, viewId.FullyQualifiedName(), asQuery)) - require.NoError(t, err) - - return func() { - _, err := client.ExecForTests(ctx, fmt.Sprintf(`DROP VIEW %s`, viewId.FullyQualifiedName())) - require.NoError(t, err) - } -} - -func createRowAccessPolicy(t *testing.T, client *sdk.Client, schema *sdk.Schema) (sdk.SchemaObjectIdentifier, func()) { - t.Helper() - ctx := context.Background() - id := sdk.NewSchemaObjectIdentifier(schema.DatabaseName, schema.Name, random.String()) - - arg := sdk.NewCreateRowAccessPolicyArgsRequest("A", sdk.DataTypeNumber) - body := "true" - createRequest := sdk.NewCreateRowAccessPolicyRequest(id, []sdk.CreateRowAccessPolicyArgsRequest{*arg}, body) - err := client.RowAccessPolicies.Create(ctx, createRequest) - require.NoError(t, err) - - return id, func() { - err := client.RowAccessPolicies.Drop(ctx, sdk.NewDropRowAccessPolicyRequest(id)) - require.NoError(t, err) - } -} - -func createApiIntegration(t *testing.T, client *sdk.Client) (sdk.AccountObjectIdentifier, func()) { - t.Helper() - ctx := context.Background() - id := sdk.NewAccountObjectIdentifier(random.String()) - apiAllowedPrefixes := []sdk.ApiIntegrationEndpointPrefix{{Path: "https://xyz.execute-api.us-west-2.amazonaws.com/production"}} - req := sdk.NewCreateApiIntegrationRequest(id, apiAllowedPrefixes, true) - req.WithAwsApiProviderParams(sdk.NewAwsApiParamsRequest(sdk.ApiIntegrationAwsApiGateway, "arn:aws:iam::123456789012:role/hello_cloud_account_role")) - err := client.ApiIntegrations.Create(ctx, req) - require.NoError(t, err) - - return id, func() { - err := client.ApiIntegrations.Drop(ctx, sdk.NewDropApiIntegrationRequest(id)) - require.NoError(t, err) - } -} - -func createExternalFunction(t *testing.T, client *sdk.Client, schema *sdk.Schema) (sdk.SchemaObjectIdentifier, func()) { - t.Helper() - ctx := context.Background() - apiIntegration, cleanupApiIntegration := createApiIntegration(t, client) - id := sdk.NewSchemaObjectIdentifier(schema.DatabaseName, schema.Name, random.StringN(4)) - argument := sdk.NewExternalFunctionArgumentRequest("x", sdk.DataTypeVARCHAR) - argumentsRequest := []sdk.ExternalFunctionArgumentRequest{*argument} - as := "https://xyz.execute-api.us-west-2.amazonaws.com/production/remote_echo" - request := sdk.NewCreateExternalFunctionRequest(id, sdk.DataTypeVariant, &apiIntegration, as). - WithOrReplace(sdk.Bool(true)). - WithArguments(argumentsRequest) - err := client.ExternalFunctions.Create(ctx, request) - require.NoError(t, err) - return id, func() { - cleanupApiIntegration() - err = client.Functions.Drop(ctx, sdk.NewDropFunctionRequest(id, []sdk.DataType{sdk.DataTypeVARCHAR})) - require.NoError(t, err) - } -} - -// TODO: extract getting row access policies as resource (like getting tag in system functions) -// getRowAccessPolicyFor is based on https://docs.snowflake.com/en/user-guide/security-row-intro#obtain-database-objects-with-a-row-access-policy. -func getRowAccessPolicyFor(t *testing.T, client *sdk.Client, id sdk.SchemaObjectIdentifier, objectType sdk.ObjectType) (*policyReference, error) { - t.Helper() - ctx := context.Background() - - s := &policyReference{} - policyReferencesId := sdk.NewSchemaObjectIdentifier(id.DatabaseName(), "INFORMATION_SCHEMA", "POLICY_REFERENCES") - err := client.QueryOneForTests(ctx, s, fmt.Sprintf(`SELECT * FROM TABLE(%s(REF_ENTITY_NAME => '%s', REF_ENTITY_DOMAIN => '%v'))`, policyReferencesId.FullyQualifiedName(), id.FullyQualifiedName(), objectType)) - - return s, err -} - -type policyReference struct { - PolicyDb string `db:"POLICY_DB"` - PolicySchema string `db:"POLICY_SCHEMA"` - PolicyName string `db:"POLICY_NAME"` - PolicyKind string `db:"POLICY_KIND"` - RefDatabaseName string `db:"REF_DATABASE_NAME"` - RefSchemaName string `db:"REF_SCHEMA_NAME"` - RefEntityName string `db:"REF_ENTITY_NAME"` - RefEntityDomain string `db:"REF_ENTITY_DOMAIN"` - RefColumnName sql.NullString `db:"REF_COLUMN_NAME"` - RefArgColumnNames string `db:"REF_ARG_COLUMN_NAMES"` - TagDatabase sql.NullString `db:"TAG_DATABASE"` - TagSchema sql.NullString `db:"TAG_SCHEMA"` - TagName sql.NullString `db:"TAG_NAME"` - PolicyStatus string `db:"POLICY_STATUS"` -} - -// TODO: extract getting table columns as resource (like getting tag in system functions) -// getTableColumnsFor is based on https://docs.snowflake.com/en/sql-reference/info-schema/columns. -func getTableColumnsFor(t *testing.T, client *sdk.Client, tableId sdk.SchemaObjectIdentifier) []informationSchemaColumns { - t.Helper() - ctx := context.Background() - - var columns []informationSchemaColumns - query := fmt.Sprintf("SELECT * FROM information_schema.columns WHERE table_schema = '%s' AND table_name = '%s' ORDER BY ordinal_position", tableId.SchemaName(), tableId.Name()) - err := client.QueryForTests(ctx, &columns, query) - require.NoError(t, err) - - return columns -} - -type informationSchemaColumns struct { - TableCatalog string `db:"TABLE_CATALOG"` - TableSchema string `db:"TABLE_SCHEMA"` - TableName string `db:"TABLE_NAME"` - ColumnName string `db:"COLUMN_NAME"` - OrdinalPosition string `db:"ORDINAL_POSITION"` - ColumnDefault sql.NullString `db:"COLUMN_DEFAULT"` - IsNullable string `db:"IS_NULLABLE"` - DataType string `db:"DATA_TYPE"` - CharacterMaximumLength sql.NullString `db:"CHARACTER_MAXIMUM_LENGTH"` - CharacterOctetLength sql.NullString `db:"CHARACTER_OCTET_LENGTH"` - NumericPrecision sql.NullString `db:"NUMERIC_PRECISION"` - NumericPrecisionRadix sql.NullString `db:"NUMERIC_PRECISION_RADIX"` - NumericScale sql.NullString `db:"NUMERIC_SCALE"` - DatetimePrecision sql.NullString `db:"DATETIME_PRECISION"` - IntervalType sql.NullString `db:"INTERVAL_TYPE"` - IntervalPrecision sql.NullString `db:"INTERVAL_PRECISION"` - CharacterSetCatalog sql.NullString `db:"CHARACTER_SET_CATALOG"` - CharacterSetSchema sql.NullString `db:"CHARACTER_SET_SCHEMA"` - CharacterSetName sql.NullString `db:"CHARACTER_SET_NAME"` - CollationCatalog sql.NullString `db:"COLLATION_CATALOG"` - CollationSchema sql.NullString `db:"COLLATION_SCHEMA"` - CollationName sql.NullString `db:"COLLATION_NAME"` - DomainCatalog sql.NullString `db:"DOMAIN_CATALOG"` - DomainSchema sql.NullString `db:"DOMAIN_SCHEMA"` - DomainName sql.NullString `db:"DOMAIN_NAME"` - UdtCatalog sql.NullString `db:"UDT_CATALOG"` - UdtSchema sql.NullString `db:"UDT_SCHEMA"` - UdtName sql.NullString `db:"UDT_NAME"` - ScopeCatalog sql.NullString `db:"SCOPE_CATALOG"` - ScopeSchema sql.NullString `db:"SCOPE_SCHEMA"` - ScopeName sql.NullString `db:"SCOPE_NAME"` - MaximumCardinality sql.NullString `db:"MAXIMUM_CARDINALITY"` - DtdIdentifier sql.NullString `db:"DTD_IDENTIFIER"` - IsSelfReferencing string `db:"IS_SELF_REFERENCING"` - IsIdentity string `db:"IS_IDENTITY"` - IdentityGeneration sql.NullString `db:"IDENTITY_GENERATION"` - IdentityStart sql.NullString `db:"IDENTITY_START"` - IdentityIncrement sql.NullString `db:"IDENTITY_INCREMENT"` - IdentityMaximum sql.NullString `db:"IDENTITY_MAXIMUM"` - IdentityMinimum sql.NullString `db:"IDENTITY_MINIMUM"` - IdentityCycle sql.NullString `db:"IDENTITY_CYCLE"` - IdentityOrdered sql.NullString `db:"IDENTITY_ORDERED"` - Comment sql.NullString `db:"COMMENT"` -} - -func updateAccountParameterTemporarily(t *testing.T, client *sdk.Client, parameter sdk.AccountParameter, newValue string) func() { - t.Helper() - ctx := context.Background() - - param, err := client.Parameters.ShowAccountParameter(ctx, parameter) - require.NoError(t, err) - oldValue := param.Value - - err = client.Parameters.SetAccountParameter(ctx, parameter, newValue) - require.NoError(t, err) - - return func() { - err = client.Parameters.SetAccountParameter(ctx, parameter, oldValue) - require.NoError(t, err) - } -} - -func createTaskWithRequest(t *testing.T, client *sdk.Client, request *sdk.CreateTaskRequest) (*sdk.Task, func()) { - t.Helper() - ctx := context.Background() - - id := request.GetName() - - err := client.Tasks.Create(ctx, request) - require.NoError(t, err) - - task, err := client.Tasks.ShowByID(ctx, id) - require.NoError(t, err) - - return task, func() { - err = client.Tasks.Drop(ctx, sdk.NewDropTaskRequest(id)) - require.NoError(t, err) - } -} - -func createTask(t *testing.T, client *sdk.Client, database *sdk.Database, schema *sdk.Schema) (*sdk.Task, func()) { - t.Helper() - id := sdk.NewSchemaObjectIdentifier(database.Name, schema.Name, random.AlphaN(20)) - warehouseReq := sdk.NewCreateTaskWarehouseRequest().WithWarehouse(sdk.Pointer(testWarehouse(t).ID())) - return createTaskWithRequest(t, client, sdk.NewCreateTaskRequest(id, "SELECT CURRENT_TIMESTAMP").WithSchedule(sdk.String("60 minutes")).WithWarehouse(warehouseReq)) -} diff --git a/pkg/sdk/testint/masking_policy_integration_test.go b/pkg/sdk/testint/masking_policy_integration_test.go index f76f9b2910..80d3710219 100644 --- a/pkg/sdk/testint/masking_policy_integration_test.go +++ b/pkg/sdk/testint/masking_policy_integration_test.go @@ -373,7 +373,7 @@ func TestInt_MaskingPolicyDrop(t *testing.T) { maskingPolicy, maskingPolicyCleanup := testClientHelper().MaskingPolicy.CreateMaskingPolicy(t) t.Cleanup(maskingPolicyCleanup) id := maskingPolicy.ID() - err := client.MaskingPolicies.Drop(ctx, id, &sdk.DropMaskingPolicyOptions{}) + err := client.MaskingPolicies.Drop(ctx, id, nil) require.NoError(t, err) _, err = client.MaskingPolicies.Describe(ctx, id) assert.ErrorIs(t, err, sdk.ErrObjectNotExistOrAuthorized) @@ -381,7 +381,7 @@ func TestInt_MaskingPolicyDrop(t *testing.T) { t.Run("when masking policy does not exist", func(t *testing.T) { id := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, "does_not_exist") - err := client.MaskingPolicies.Drop(ctx, id, &sdk.DropMaskingPolicyOptions{}) + err := client.MaskingPolicies.Drop(ctx, id, nil) assert.ErrorIs(t, err, sdk.ErrObjectNotExistOrAuthorized) }) } diff --git a/pkg/sdk/testint/materialized_views_gen_integration_test.go b/pkg/sdk/testint/materialized_views_gen_integration_test.go index 501d877988..3bda131a53 100644 --- a/pkg/sdk/testint/materialized_views_gen_integration_test.go +++ b/pkg/sdk/testint/materialized_views_gen_integration_test.go @@ -113,7 +113,7 @@ func TestInt_MaterializedViews(t *testing.T) { }) t.Run("create materialized view: almost complete case", func(t *testing.T) { - rowAccessPolicyId, rowAccessPolicyCleanup := createRowAccessPolicy(t, client, testSchema(t)) + rowAccessPolicy, rowAccessPolicyCleanup := testClientHelper().RowAccessPolicy.CreateRowAccessPolicy(t) t.Cleanup(rowAccessPolicyCleanup) tag, tagCleanup := testClientHelper().Tag.CreateTag(t) @@ -127,7 +127,7 @@ func TestInt_MaterializedViews(t *testing.T) { }). WithCopyGrants(sdk.Bool(true)). WithComment(sdk.String("comment")). - WithRowAccessPolicy(sdk.NewMaterializedViewRowAccessPolicyRequest(rowAccessPolicyId, []string{"column_with_comment"})). + WithRowAccessPolicy(sdk.NewMaterializedViewRowAccessPolicyRequest(rowAccessPolicy.ID(), []string{"column_with_comment"})). WithTag([]sdk.TagAssociation{{ Name: tag.ID(), Value: "v2", @@ -139,9 +139,9 @@ func TestInt_MaterializedViews(t *testing.T) { view := createMaterializedViewWithRequest(t, request) assertMaterializedViewWithOptions(t, view, id, true, "comment", fmt.Sprintf(`LINEAR("%s")`, "COLUMN_WITH_COMMENT")) - rowAccessPolicyReference, err := getRowAccessPolicyFor(t, client, view.ID(), sdk.ObjectTypeView) + rowAccessPolicyReference, err := testClientHelper().RowAccessPolicy.GetRowAccessPolicyFor(t, view.ID(), sdk.ObjectTypeView) require.NoError(t, err) - assert.Equal(t, rowAccessPolicyId.Name(), rowAccessPolicyReference.PolicyName) + assert.Equal(t, rowAccessPolicy.Name, rowAccessPolicyReference.PolicyName) assert.Equal(t, "ROW_ACCESS_POLICY", rowAccessPolicyReference.PolicyKind) assert.Equal(t, view.ID().Name(), rowAccessPolicyReference.RefEntityName) assert.Equal(t, "MATERIALIZED_VIEW", rowAccessPolicyReference.RefEntityDomain) diff --git a/pkg/sdk/testint/resource_monitors_integration_test.go b/pkg/sdk/testint/resource_monitors_integration_test.go index a7f7e30be0..1c21760881 100644 --- a/pkg/sdk/testint/resource_monitors_integration_test.go +++ b/pkg/sdk/testint/resource_monitors_integration_test.go @@ -306,7 +306,7 @@ func TestInt_ResourceMonitorDrop(t *testing.T) { resourceMonitor, resourceMonitorCleanup := testClientHelper().ResourceMonitor.CreateResourceMonitor(t) t.Cleanup(resourceMonitorCleanup) id := resourceMonitor.ID() - err := client.ResourceMonitors.Drop(ctx, id, &sdk.DropResourceMonitorOptions{}) + err := client.ResourceMonitors.Drop(ctx, id, nil) require.NoError(t, err) _, err = client.ResourceMonitors.ShowByID(ctx, id) assert.ErrorIs(t, err, sdk.ErrObjectNotExistOrAuthorized) @@ -314,7 +314,7 @@ func TestInt_ResourceMonitorDrop(t *testing.T) { t.Run("when resource monitor does not exist", func(t *testing.T) { id := sdk.NewAccountObjectIdentifier("does_not_exist") - err := client.ResourceMonitors.Drop(ctx, id, &sdk.DropResourceMonitorOptions{}) + err := client.ResourceMonitors.Drop(ctx, id, nil) assert.ErrorIs(t, err, sdk.ErrObjectNotExistOrAuthorized) }) } diff --git a/pkg/sdk/testint/shares_integration_test.go b/pkg/sdk/testint/shares_integration_test.go index 19891adecd..4f64353892 100644 --- a/pkg/sdk/testint/shares_integration_test.go +++ b/pkg/sdk/testint/shares_integration_test.go @@ -12,10 +12,10 @@ import ( func TestInt_SharesShow(t *testing.T) { client := testClient(t) ctx := testContext(t) - shareTest, shareCleanup := createShare(t, client) + shareTest, shareCleanup := testClientHelper().Share.CreateShare(t) t.Cleanup(shareCleanup) - _, shareCleanup2 := createShare(t, client) + _, shareCleanup2 := testClientHelper().Share.CreateShare(t) t.Cleanup(shareCleanup2) t.Run("without show options", func(t *testing.T) { @@ -83,10 +83,7 @@ func TestInt_SharesCreate(t *testing.T) { assert.Equal(t, id.Name(), shares[0].Name.Name()) assert.Equal(t, "test comment", shares[0].Comment) - t.Cleanup(func() { - err := client.Shares.Drop(ctx, id) - require.NoError(t, err) - }) + t.Cleanup(testClientHelper().Share.DropShareFunc(t, id)) }) t.Run("test no options", func(t *testing.T) { @@ -100,10 +97,7 @@ func TestInt_SharesCreate(t *testing.T) { require.NoError(t, err) assert.GreaterOrEqual(t, len(shares), 1) - t.Cleanup(func() { - err := client.Shares.Drop(ctx, id) - require.NoError(t, err) - }) + t.Cleanup(testClientHelper().Share.DropShareFunc(t, id)) }) } @@ -112,13 +106,14 @@ func TestInt_SharesDrop(t *testing.T) { ctx := testContext(t) t.Run("when share exists", func(t *testing.T) { - shareTest, _ := createShare(t, client) - err := client.Shares.Drop(ctx, shareTest.ID()) + shareTest, shareCleanup := testClientHelper().Share.CreateShare(t) + t.Cleanup(shareCleanup) + err := client.Shares.Drop(ctx, shareTest.ID(), &sdk.DropShareOptions{}) require.NoError(t, err) }) t.Run("when share does not exist", func(t *testing.T) { - err := client.Shares.Drop(ctx, sdk.NewAccountObjectIdentifier("does_not_exist")) + err := client.Shares.Drop(ctx, sdk.NewAccountObjectIdentifier("does_not_exist"), &sdk.DropShareOptions{}) assert.ErrorIs(t, err, sdk.ErrObjectNotExistOrAuthorized) }) } @@ -129,7 +124,7 @@ func TestInt_SharesAlter(t *testing.T) { ctx := testContext(t) t.Run("add and remove accounts", func(t *testing.T) { - shareTest, shareCleanup := createShare(t, client) + shareTest, shareCleanup := testClientHelper().Share.CreateShare(t) t.Cleanup(shareCleanup) err := client.Grants.GrantPrivilegeToShare(ctx, []sdk.ObjectPrivilege{sdk.ObjectPrivilegeUsage}, &sdk.ShareGrantOn{ Database: testDb(t).ID(), @@ -142,7 +137,7 @@ func TestInt_SharesAlter(t *testing.T) { }) require.NoError(t, err) accountsToAdd := []sdk.AccountIdentifier{ - getAccountIdentifier(t, secondaryClient), + secondaryTestClientHelper().Account.GetAccountIdentifier(t), } // first add the account. err = client.Shares.Alter(ctx, shareTest.ID(), &sdk.AlterShareOptions{ @@ -186,7 +181,7 @@ func TestInt_SharesAlter(t *testing.T) { db, dbCleanup := secondaryTestClientHelper().Database.CreateDatabase(t) t.Cleanup(dbCleanup) - shareTest, shareCleanup := createShare(t, secondaryClient) + shareTest, shareCleanup := secondaryTestClientHelper().Share.CreateShare(t) t.Cleanup(shareCleanup) err := secondaryClient.Grants.GrantPrivilegeToShare(ctx, []sdk.ObjectPrivilege{sdk.ObjectPrivilegeUsage}, &sdk.ShareGrantOn{ @@ -201,7 +196,7 @@ func TestInt_SharesAlter(t *testing.T) { }) accountsToSet := []sdk.AccountIdentifier{ - getAccountIdentifier(t, client), + testClientHelper().Account.GetAccountIdentifier(t), } // first add the account. @@ -226,7 +221,7 @@ func TestInt_SharesAlter(t *testing.T) { }) t.Run("set and unset comment", func(t *testing.T) { - shareTest, shareCleanup := createShare(t, client) + shareTest, shareCleanup := testClientHelper().Share.CreateShare(t) t.Cleanup(shareCleanup) err := client.Grants.GrantPrivilegeToShare(ctx, []sdk.ObjectPrivilege{sdk.ObjectPrivilegeUsage}, &sdk.ShareGrantOn{ @@ -282,7 +277,7 @@ func TestInt_SharesAlter(t *testing.T) { }) t.Run("set and unset tags", func(t *testing.T) { - shareTest, shareCleanup := createShare(t, client) + shareTest, shareCleanup := testClientHelper().Share.CreateShare(t) t.Cleanup(shareCleanup) err := client.Grants.GrantPrivilegeToShare(ctx, []sdk.ObjectPrivilege{sdk.ObjectPrivilegeUsage}, &sdk.ShareGrantOn{ Database: testDb(t).ID(), @@ -342,7 +337,7 @@ func TestInt_ShareDescribeProvider(t *testing.T) { ctx := testContext(t) t.Run("describe share", func(t *testing.T) { - shareTest, shareCleanup := createShare(t, client) + shareTest, shareCleanup := testClientHelper().Share.CreateShare(t) t.Cleanup(shareCleanup) err := client.Grants.GrantPrivilegeToShare(ctx, []sdk.ObjectPrivilege{sdk.ObjectPrivilegeUsage}, &sdk.ShareGrantOn{ @@ -375,7 +370,7 @@ func TestInt_ShareDescribeConsumer(t *testing.T) { db, dbCleanup := secondaryTestClientHelper().Database.CreateDatabase(t) t.Cleanup(dbCleanup) - shareTest, shareCleanup := createShare(t, providerClient) + shareTest, shareCleanup := secondaryTestClientHelper().Share.CreateShare(t) t.Cleanup(shareCleanup) err := providerClient.Grants.GrantPrivilegeToShare(ctx, []sdk.ObjectPrivilege{sdk.ObjectPrivilegeUsage}, &sdk.ShareGrantOn{ @@ -393,7 +388,7 @@ func TestInt_ShareDescribeConsumer(t *testing.T) { err = providerClient.Shares.Alter(ctx, shareTest.ID(), &sdk.AlterShareOptions{ Add: &sdk.ShareAdd{ Accounts: []sdk.AccountIdentifier{ - getAccountIdentifier(t, consumerClient), + testClientHelper().Account.GetAccountIdentifier(t), }, }, }) diff --git a/pkg/sdk/testint/streams_gen_integration_test.go b/pkg/sdk/testint/streams_gen_integration_test.go index d1adc1fa6d..f5e99fb7a2 100644 --- a/pkg/sdk/testint/streams_gen_integration_test.go +++ b/pkg/sdk/testint/streams_gen_integration_test.go @@ -55,7 +55,7 @@ func TestInt_Streams(t *testing.T) { stageName := random.AlphaN(10) stageID := sdk.NewSchemaObjectIdentifier(TestDatabaseName, TestSchemaName, stageName) stageLocation := fmt.Sprintf("@%s", stageID.FullyQualifiedName()) - _, stageCleanup := testClientHelper().Stage.CreateStageWithURL(t, stageID, nycWeatherDataURL) + _, stageCleanup := testClientHelper().Stage.CreateStageWithURL(t, stageID) t.Cleanup(stageCleanup) externalTableId := sdk.NewSchemaObjectIdentifier(db.Name, schema.Name, random.AlphanumericN(32)) @@ -107,12 +107,11 @@ func TestInt_Streams(t *testing.T) { tableId := sdk.NewSchemaObjectIdentifier(db.Name, schema.Name, table.Name) t.Cleanup(cleanupTable) - viewId := sdk.NewSchemaObjectIdentifier(db.Name, schema.Name, random.AlphanumericN(32)) - cleanupView := createView(t, client, viewId, fmt.Sprintf("SELECT id FROM %s", tableId.FullyQualifiedName())) + view, cleanupView := testClientHelper().View.CreateView(t, fmt.Sprintf("SELECT id FROM %s", tableId.FullyQualifiedName())) t.Cleanup(cleanupView) id := sdk.NewSchemaObjectIdentifier(db.Name, schema.Name, random.AlphanumericN(32)) - req := sdk.NewCreateStreamOnViewRequest(id, viewId).WithComment(sdk.String("some comment")) + req := sdk.NewCreateStreamOnViewRequest(id, view.ID()).WithComment(sdk.String("some comment")) err := client.Streams.CreateOnView(ctx, req) require.NoError(t, err) t.Cleanup(func() { diff --git a/pkg/sdk/testint/tables_integration_test.go b/pkg/sdk/testint/tables_integration_test.go index 2a67278eb3..030e9412f6 100644 --- a/pkg/sdk/testint/tables_integration_test.go +++ b/pkg/sdk/testint/tables_integration_test.go @@ -8,6 +8,7 @@ import ( "strings" "testing" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/helpers/random" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/internal/collections" @@ -38,7 +39,7 @@ func TestInt_Table(t *testing.T) { tag2, tagCleanup2 := testClientHelper().Tag.CreateTag(t) t.Cleanup(tagCleanup2) - assertColumns := func(t *testing.T, expectedColumns []expectedColumn, createdColumns []informationSchemaColumns) { + assertColumns := func(t *testing.T, expectedColumns []expectedColumn, createdColumns []helpers.InformationSchemaColumns) { t.Helper() require.Len(t, createdColumns, len(expectedColumns)) @@ -154,7 +155,7 @@ func TestInt_Table(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "30", param.Value) - tableColumns := getTableColumnsFor(t, client, table.ID()) + tableColumns := testClientHelper().Table.GetTableColumnsFor(t, table.ID()) expectedColumns := []expectedColumn{ {"COLUMN_3", sdk.DataTypeVARCHAR}, {"COLUMN_1", sdk.DataTypeVARCHAR}, @@ -193,7 +194,7 @@ func TestInt_Table(t *testing.T) { table, err := client.Tables.ShowByID(ctx, id) require.NoError(t, err) - tableColumns := getTableColumnsFor(t, client, table.ID()) + tableColumns := testClientHelper().Table.GetTableColumnsFor(t, table.ID()) expectedColumns := []expectedColumn{ {"COLUMN_3", sdk.DataTypeVARCHAR}, {"COLUMN_1", sdk.DataTypeVARCHAR}, @@ -234,7 +235,7 @@ func TestInt_Table(t *testing.T) { table, err := client.Tables.ShowByID(ctx, id) require.NoError(t, err) - returnedTableColumns := getTableColumnsFor(t, client, table.ID()) + returnedTableColumns := testClientHelper().Table.GetTableColumnsFor(t, table.ID()) expectedColumns := []expectedColumn{ {"C1", sdk.DataTypeVARCHAR}, {"C2", sdk.DataTypeVARCHAR}, @@ -261,7 +262,7 @@ func TestInt_Table(t *testing.T) { require.NoError(t, err) t.Cleanup(cleanupTableProvider(id)) - sourceTableColumns := getTableColumnsFor(t, client, sourceTable.ID()) + sourceTableColumns := testClientHelper().Table.GetTableColumnsFor(t, sourceTable.ID()) expectedColumns := []expectedColumn{ {"id", sdk.DataTypeNumber}, {"col2", sdk.DataTypeVARCHAR}, @@ -272,7 +273,7 @@ func TestInt_Table(t *testing.T) { likeTable, err := client.Tables.ShowByID(ctx, id) require.NoError(t, err) - likeTableColumns := getTableColumnsFor(t, client, likeTable.ID()) + likeTableColumns := testClientHelper().Table.GetTableColumnsFor(t, likeTable.ID()) assertColumns(t, expectedColumns, likeTableColumns) }) @@ -294,14 +295,14 @@ func TestInt_Table(t *testing.T) { WithMoment(sdk.CloneMomentAt)) // ensure that time travel is allowed (and revert if needed after the test) - revertParameter := updateAccountParameterTemporarily(t, client, sdk.AccountParameterDataRetentionTimeInDays, "1") + revertParameter := testClientHelper().Parameter.UpdateAccountParameterTemporarily(t, sdk.AccountParameterDataRetentionTimeInDays, "1") t.Cleanup(revertParameter) err := client.Tables.CreateClone(ctx, request) require.NoError(t, err) t.Cleanup(cleanupTableProvider(id)) - sourceTableColumns := getTableColumnsFor(t, client, sourceTable.ID()) + sourceTableColumns := testClientHelper().Table.GetTableColumnsFor(t, sourceTable.ID()) expectedColumns := []expectedColumn{ {"id", sdk.DataTypeNumber}, {"col2", sdk.DataTypeVARCHAR}, @@ -312,7 +313,7 @@ func TestInt_Table(t *testing.T) { cloneTable, err := client.Tables.ShowByID(ctx, id) require.NoError(t, err) - cloneTableColumns := getTableColumnsFor(t, client, cloneTable.ID()) + cloneTableColumns := testClientHelper().Table.GetTableColumnsFor(t, cloneTable.ID()) assertColumns(t, expectedColumns, cloneTableColumns) }) @@ -482,7 +483,7 @@ func TestInt_Table(t *testing.T) { table, err := client.Tables.ShowByID(ctx, id) require.NoError(t, err) - currentColumns := getTableColumnsFor(t, client, table.ID()) + currentColumns := testClientHelper().Table.GetTableColumnsFor(t, table.ID()) expectedColumns := []expectedColumn{ {"COLUMN_1", sdk.DataTypeVARCHAR}, {"COLUMN_2", sdk.DataTypeVARCHAR}, @@ -514,7 +515,7 @@ func TestInt_Table(t *testing.T) { table, err := client.Tables.ShowByID(ctx, id) require.NoError(t, err) - currentColumns := getTableColumnsFor(t, client, table.ID()) + currentColumns := testClientHelper().Table.GetTableColumnsFor(t, table.ID()) expectedColumns := []expectedColumn{ {"COLUMN_3", sdk.DataTypeVARCHAR}, {"COLUMN_2", sdk.DataTypeVARCHAR}, @@ -630,7 +631,7 @@ func TestInt_Table(t *testing.T) { table, err := client.Tables.ShowByID(ctx, id) require.NoError(t, err) - currentColumns := getTableColumnsFor(t, client, table.ID()) + currentColumns := testClientHelper().Table.GetTableColumnsFor(t, table.ID()) expectedColumns := []expectedColumn{ {"COLUMN_2", sdk.DataTypeVARCHAR}, } @@ -800,7 +801,7 @@ func TestInt_Table(t *testing.T) { table, err := client.Tables.ShowByID(ctx, id) require.NoError(t, err) - currentColumns := getTableColumnsFor(t, client, table.ID()) + currentColumns := testClientHelper().Table.GetTableColumnsFor(t, table.ID()) expectedColumns := []expectedColumn{ {"COLUMN_1", sdk.DataTypeVARCHAR}, {"COLUMN_2", sdk.DataTypeVARCHAR}, @@ -830,7 +831,7 @@ func TestInt_Table(t *testing.T) { require.NoError(t, err) assert.Equal(t, table.Comment, "") - currentColumns := getTableColumnsFor(t, client, table.ID()) + currentColumns := testClientHelper().Table.GetTableColumnsFor(t, table.ID()) expectedColumns := []expectedColumn{ {"COLUMN_3", sdk.DataTypeVARCHAR}, {"COLUMN_2", sdk.DataTypeVARCHAR}, @@ -858,7 +859,7 @@ func TestInt_Table(t *testing.T) { table, err := client.Tables.ShowByID(ctx, id) require.NoError(t, err) - currentColumns := getTableColumnsFor(t, client, table.ID()) + currentColumns := testClientHelper().Table.GetTableColumnsFor(t, table.ID()) expectedColumns := []expectedColumn{ {"COLUMN_1", sdk.DataTypeVARCHAR}, } diff --git a/pkg/sdk/testint/views_gen_integration_test.go b/pkg/sdk/testint/views_gen_integration_test.go index e6638f5157..937773c217 100644 --- a/pkg/sdk/testint/views_gen_integration_test.go +++ b/pkg/sdk/testint/views_gen_integration_test.go @@ -136,7 +136,7 @@ func TestInt_Views(t *testing.T) { }) t.Run("create view: almost complete case", func(t *testing.T) { - rowAccessPolicyId, rowAccessPolicyCleanup := createRowAccessPolicy(t, client, testSchema(t)) + rowAccessPolicy, rowAccessPolicyCleanup := testClientHelper().RowAccessPolicy.CreateRowAccessPolicy(t) t.Cleanup(rowAccessPolicyCleanup) tag, tagCleanup := testClientHelper().Tag.CreateTag(t) @@ -151,7 +151,7 @@ func TestInt_Views(t *testing.T) { }). WithCopyGrants(sdk.Bool(true)). WithComment(sdk.String("comment")). - WithRowAccessPolicy(sdk.NewViewRowAccessPolicyRequest(rowAccessPolicyId, []string{"column_with_comment"})). + WithRowAccessPolicy(sdk.NewViewRowAccessPolicyRequest(rowAccessPolicy.ID(), []string{"column_with_comment"})). WithTag([]sdk.TagAssociation{{ Name: tag.ID(), Value: "v2", @@ -162,9 +162,9 @@ func TestInt_Views(t *testing.T) { view := createViewWithRequest(t, request) assertViewWithOptions(t, view, id, true, "comment") - rowAccessPolicyReference, err := getRowAccessPolicyFor(t, client, view.ID(), sdk.ObjectTypeView) + rowAccessPolicyReference, err := testClientHelper().RowAccessPolicy.GetRowAccessPolicyFor(t, view.ID(), sdk.ObjectTypeView) require.NoError(t, err) - assert.Equal(t, rowAccessPolicyId.Name(), rowAccessPolicyReference.PolicyName) + assert.Equal(t, rowAccessPolicy.Name, rowAccessPolicyReference.PolicyName) assert.Equal(t, "ROW_ACCESS_POLICY", rowAccessPolicyReference.PolicyKind) assert.Equal(t, view.ID().Name(), rowAccessPolicyReference.RefEntityName) assert.Equal(t, "VIEW", rowAccessPolicyReference.RefEntityDomain) @@ -390,62 +390,62 @@ func TestInt_Views(t *testing.T) { }) t.Run("alter view: add and drop row access policies", func(t *testing.T) { - rowAccessPolicyId, rowAccessPolicyCleanup := createRowAccessPolicy(t, client, testSchema(t)) + rowAccessPolicy, rowAccessPolicyCleanup := testClientHelper().RowAccessPolicy.CreateRowAccessPolicy(t) t.Cleanup(rowAccessPolicyCleanup) - rowAccessPolicy2Id, rowAccessPolicy2Cleanup := createRowAccessPolicy(t, client, testSchema(t)) + rowAccessPolicy2, rowAccessPolicy2Cleanup := testClientHelper().RowAccessPolicy.CreateRowAccessPolicy(t) t.Cleanup(rowAccessPolicy2Cleanup) view := createView(t) id := view.ID() // add policy - alterRequest := sdk.NewAlterViewRequest(id).WithAddRowAccessPolicy(sdk.NewViewAddRowAccessPolicyRequest(rowAccessPolicyId, []string{"ID"})) + alterRequest := sdk.NewAlterViewRequest(id).WithAddRowAccessPolicy(sdk.NewViewAddRowAccessPolicyRequest(rowAccessPolicy.ID(), []string{"ID"})) err := client.Views.Alter(ctx, alterRequest) require.NoError(t, err) - rowAccessPolicyReference, err := getRowAccessPolicyFor(t, client, view.ID(), sdk.ObjectTypeView) + rowAccessPolicyReference, err := testClientHelper().RowAccessPolicy.GetRowAccessPolicyFor(t, view.ID(), sdk.ObjectTypeView) require.NoError(t, err) - assert.Equal(t, rowAccessPolicyId.Name(), rowAccessPolicyReference.PolicyName) + assert.Equal(t, rowAccessPolicy.ID().Name(), rowAccessPolicyReference.PolicyName) assert.Equal(t, "ROW_ACCESS_POLICY", rowAccessPolicyReference.PolicyKind) assert.Equal(t, view.ID().Name(), rowAccessPolicyReference.RefEntityName) assert.Equal(t, "VIEW", rowAccessPolicyReference.RefEntityDomain) assert.Equal(t, "ACTIVE", rowAccessPolicyReference.PolicyStatus) // remove policy - alterRequest = sdk.NewAlterViewRequest(id).WithDropRowAccessPolicy(sdk.NewViewDropRowAccessPolicyRequest(rowAccessPolicyId)) + alterRequest = sdk.NewAlterViewRequest(id).WithDropRowAccessPolicy(sdk.NewViewDropRowAccessPolicyRequest(rowAccessPolicy.ID())) err = client.Views.Alter(ctx, alterRequest) require.NoError(t, err) - _, err = getRowAccessPolicyFor(t, client, view.ID(), sdk.ObjectTypeView) + _, err = testClientHelper().RowAccessPolicy.GetRowAccessPolicyFor(t, view.ID(), sdk.ObjectTypeView) require.Error(t, err, "no rows in result set") // add policy again - alterRequest = sdk.NewAlterViewRequest(id).WithAddRowAccessPolicy(sdk.NewViewAddRowAccessPolicyRequest(rowAccessPolicyId, []string{"ID"})) + alterRequest = sdk.NewAlterViewRequest(id).WithAddRowAccessPolicy(sdk.NewViewAddRowAccessPolicyRequest(rowAccessPolicy.ID(), []string{"ID"})) err = client.Views.Alter(ctx, alterRequest) require.NoError(t, err) - rowAccessPolicyReference, err = getRowAccessPolicyFor(t, client, view.ID(), sdk.ObjectTypeView) + rowAccessPolicyReference, err = testClientHelper().RowAccessPolicy.GetRowAccessPolicyFor(t, view.ID(), sdk.ObjectTypeView) require.NoError(t, err) - assert.Equal(t, rowAccessPolicyId.Name(), rowAccessPolicyReference.PolicyName) + assert.Equal(t, rowAccessPolicy.ID().Name(), rowAccessPolicyReference.PolicyName) // drop and add other policy simultaneously alterRequest = sdk.NewAlterViewRequest(id).WithDropAndAddRowAccessPolicy(sdk.NewViewDropAndAddRowAccessPolicyRequest( - *sdk.NewViewDropRowAccessPolicyRequest(rowAccessPolicyId), - *sdk.NewViewAddRowAccessPolicyRequest(rowAccessPolicy2Id, []string{"ID"}), + *sdk.NewViewDropRowAccessPolicyRequest(rowAccessPolicy.ID()), + *sdk.NewViewAddRowAccessPolicyRequest(rowAccessPolicy2.ID(), []string{"ID"}), )) err = client.Views.Alter(ctx, alterRequest) require.NoError(t, err) - rowAccessPolicyReference, err = getRowAccessPolicyFor(t, client, view.ID(), sdk.ObjectTypeView) + rowAccessPolicyReference, err = testClientHelper().RowAccessPolicy.GetRowAccessPolicyFor(t, view.ID(), sdk.ObjectTypeView) require.NoError(t, err) - assert.Equal(t, rowAccessPolicy2Id.Name(), rowAccessPolicyReference.PolicyName) + assert.Equal(t, rowAccessPolicy2.ID().Name(), rowAccessPolicyReference.PolicyName) // drop all policies alterRequest = sdk.NewAlterViewRequest(id).WithDropAllRowAccessPolicies(sdk.Bool(true)) err = client.Views.Alter(ctx, alterRequest) require.NoError(t, err) - _, err = getRowAccessPolicyFor(t, client, view.ID(), sdk.ObjectTypeView) + _, err = testClientHelper().RowAccessPolicy.GetRowAccessPolicyFor(t, view.ID(), sdk.ObjectTypeView) require.Error(t, err, "no rows in result set") })