Skip to content

Commit

Permalink
fix: Fix database show by and resource logic (#3055)
Browse files Browse the repository at this point in the history
- Use the proper way for filtering results in ShowById for database
- Add known share issue to the list
- Mark row access policy as started
- Fix policy references implementation and test
- Fix TODO comments issues links
  • Loading branch information
sfc-gh-asawicki authored Sep 10, 2024
1 parent a65e564 commit 1887e55
Show file tree
Hide file tree
Showing 11 changed files with 59 additions and 54 deletions.
4 changes: 2 additions & 2 deletions pkg/acceptance/helpers/random/test_object_suffix.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/testenvs"
)

// TODO [SNOW-955520]: add generation tests
// TODO [SNOW-955520]: use the same fallback suffix for acceptance and integration tests (now two different ones are generated if the env is missing)
// TODO [SNOW-1356199]: add generation tests
// TODO [SNOW-1356199]: use the same fallback suffix for acceptance and integration tests (now two different ones are generated if the env is missing)
var (
AcceptanceTestsSuffix = acceptanceTestsSuffix()
IntegrationTestsSuffix = integrationTestsSuffix()
Expand Down
2 changes: 1 addition & 1 deletion pkg/resources/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ func ReadDatabase(ctx context.Context, d *schema.ResourceData, meta any) diag.Di

database, err := client.Databases.ShowByID(ctx, id)
if err != nil {
if errors.Is(err, sdk.ErrObjectNotExistOrAuthorized) || errors.Is(err, sdk.ErrObjectNotFound) {
if errors.Is(err, sdk.ErrObjectNotFound) {
d.SetId("")
return diag.Diagnostics{
diag.Diagnostic{
Expand Down
9 changes: 3 additions & 6 deletions pkg/sdk/databases.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"strconv"
"strings"
"time"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections"
)

var (
Expand Down Expand Up @@ -791,12 +793,7 @@ func (v *databases) ShowByID(ctx context.Context, id AccountObjectIdentifier) (*
if err != nil {
return nil, err
}
for _, database := range databases {
if database.ID() == id {
return &database, nil
}
}
return nil, ErrObjectNotExistOrAuthorized
return collections.FindFirst(databases, func(r Database) bool { return r.Name == id.Name() })
}

type DatabaseDetails struct {
Expand Down
2 changes: 1 addition & 1 deletion pkg/sdk/policy_references.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func (row policyReferenceDBRow) convert() *PolicyReference {
if row.TagName.Valid {
policyReference.TagName = &row.TagName.String
}
if row.TagName.Valid {
if row.PolicyStatus.Valid {
policyReference.PolicyStatus = &row.PolicyStatus.String
}
return &policyReference
Expand Down
4 changes: 2 additions & 2 deletions pkg/sdk/sweepers.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ func SweepAfterAcceptanceTests(client *Client, suffix string) error {
return sweep(client, suffix)
}

// TODO [SNOW-955520]: move this to test code
// TODO [SNOW-955520]: use if exists/use method from helper for dropping
// TODO [SNOW-867247]: move this to test code
// TODO [SNOW-867247]: use if exists/use method from helper for dropping
// TODO [SNOW-867247]: sweep all missing account-level objects (like users, integrations, replication groups, network policies, ...)
// TODO [SNOW-867247]: extract sweepers to a separate dir
// TODO [SNOW-867247]: rework the sweepers (funcs -> objects)
Expand Down
16 changes: 8 additions & 8 deletions pkg/sdk/sweepers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ import (
"github.com/stretchr/testify/assert"
)

// TODO [SNOW-955520]: move the sweepers outside of the sdk package
// TODO [SNOW-955520]: use test client helpers in sweepers?
// TODO [SNOW-867247]: move the sweepers outside of the sdk package
// TODO [SNOW-867247]: use test client helpers in sweepers?
func TestSweepAll(t *testing.T) {
_ = testenvs.GetOrSkipTest(t, testenvs.EnableSweep)
testenvs.AssertEnvSet(t, string(testenvs.TestObjectsSuffix))
Expand Down Expand Up @@ -85,28 +85,28 @@ func Test_Sweeper_NukeStaleObjects(t *testing.T) {
}
})

// TODO [SNOW-955520]:
// TODO [SNOW-867247]: unskip
t.Run("sweep databases", func(t *testing.T) {
t.Skipf("Used for manual sweeping; will be addressed during SNOW-955520")
t.Skipf("Used for manual sweeping; will be addressed during SNOW-867247")
for _, c := range allClients {
err := nukeDatabases(c, "")()
assert.NoError(t, err)
}
})

// TODO [SNOW-955520]:
// TODO [SNOW-867247]: unskip
t.Run("sweep warehouses", func(t *testing.T) {
t.Skipf("Used for manual sweeping; will be addressed during SNOW-955520")
t.Skipf("Used for manual sweeping; will be addressed during SNOW-867247")
for _, c := range allClients {
err := nukeWarehouses(c, "")()
assert.NoError(t, err)
}
})

// TODO [SNOW-955520]: nuke stale objects (e.g. created more than 2 weeks ago)
// TODO [SNOW-867247]: nuke stale objects (e.g. created more than 2 weeks ago)
}

// TODO [SNOW-955520]: generalize nuke methods (sweepers too)
// TODO [SNOW-867247]: generalize nuke methods (sweepers too)
// TODO [SNOW-1658402]: handle the ownership problem while handling the better role setup for tests
func nukeWarehouses(client *Client, prefix string) func() error {
protectedWarehouses := []string{
Expand Down
14 changes: 8 additions & 6 deletions pkg/sdk/testint/event_tables_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,18 +227,19 @@ func TestInt_EventTables(t *testing.T) {
e, err := testClientHelper().PolicyReferences.GetPolicyReference(t, table.ID(), sdk.PolicyEntityDomainTable)
require.NoError(t, err)
assert.Equal(t, rowAccessPolicy.ID().Name(), e.PolicyName)
assert.Equal(t, "ROW_ACCESS_POLICY", e.PolicyKind)
assert.Equal(t, sdk.PolicyKindRowAccessPolicy, e.PolicyKind)
assert.Equal(t, table.ID().Name(), e.RefEntityName)
assert.Equal(t, "TABLE", e.RefEntityDomain)
assert.Equal(t, "ACTIVE", e.PolicyStatus)
assert.Equal(t, "ACTIVE", *e.PolicyStatus)

// remove policy
alterRequest = sdk.NewAlterEventTableRequest(table.ID()).WithDropRowAccessPolicy(sdk.NewEventTableDropRowAccessPolicyRequest(rowAccessPolicy.ID()))
err = client.EventTables.Alter(ctx, alterRequest)
require.NoError(t, err)

_, err = testClientHelper().PolicyReferences.GetPolicyReference(t, table.ID(), sdk.PolicyEntityDomainTable)
require.Error(t, err, "no rows in result set")
references, err := testClientHelper().PolicyReferences.GetPolicyReferences(t, table.ID(), sdk.PolicyEntityDomainTable)
require.NoError(t, err)
require.Empty(t, references)

// add policy again
alterRequest = sdk.NewAlterEventTableRequest(table.ID()).WithAddRowAccessPolicy(sdk.NewEventTableAddRowAccessPolicyRequest(rowAccessPolicy.ID(), []string{"id"}))
Expand Down Expand Up @@ -266,8 +267,9 @@ func TestInt_EventTables(t *testing.T) {
err = client.EventTables.Alter(ctx, alterRequest)
require.NoError(t, err)

_, err = testClientHelper().PolicyReferences.GetPolicyReference(t, table.ID(), sdk.PolicyEntityDomainView)
require.Error(t, err, "no rows in result set")
references, err = testClientHelper().PolicyReferences.GetPolicyReferences(t, table.ID(), sdk.PolicyEntityDomainView)
require.NoError(t, err)
require.Empty(t, references)
})
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/sdk/testint/materialized_views_gen_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,10 @@ func TestInt_MaterializedViews(t *testing.T) {
rowAccessPolicyReference, err := testClientHelper().PolicyReferences.GetPolicyReference(t, view.ID(), sdk.PolicyEntityDomainView)
require.NoError(t, err)
assert.Equal(t, rowAccessPolicy.Name, rowAccessPolicyReference.PolicyName)
assert.Equal(t, "ROW_ACCESS_POLICY", rowAccessPolicyReference.PolicyKind)
assert.Equal(t, sdk.PolicyKindRowAccessPolicy, rowAccessPolicyReference.PolicyKind)
assert.Equal(t, view.ID().Name(), rowAccessPolicyReference.RefEntityName)
assert.Equal(t, "MATERIALIZED_VIEW", rowAccessPolicyReference.RefEntityDomain)
assert.Equal(t, "ACTIVE", rowAccessPolicyReference.PolicyStatus)
assert.Equal(t, "ACTIVE", *rowAccessPolicyReference.PolicyStatus)
})

t.Run("drop materialized view: existing", func(t *testing.T) {
Expand Down
54 changes: 30 additions & 24 deletions pkg/sdk/testint/views_gen_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,16 @@ func TestInt_Views(t *testing.T) {

assertPolicyReference := func(t *testing.T, policyRef sdk.PolicyReference,
policyId sdk.SchemaObjectIdentifier,
policyType string,
policyKind sdk.PolicyKind,
viewId sdk.SchemaObjectIdentifier,
refColumnName *string,
) {
t.Helper()
assert.Equal(t, policyId.Name(), policyRef.PolicyName)
assert.Equal(t, policyType, policyRef.PolicyKind)
assert.Equal(t, policyKind, policyRef.PolicyKind)
assert.Equal(t, viewId.Name(), policyRef.RefEntityName)
assert.Equal(t, "VIEW", policyRef.RefEntityDomain)
assert.Equal(t, "ACTIVE", policyRef.PolicyStatus)
assert.Equal(t, "ACTIVE", *policyRef.PolicyStatus)
if refColumnName != nil {
assert.NotNil(t, policyRef.RefColumnName)
assert.Equal(t, *refColumnName, *policyRef.RefColumnName)
Expand Down Expand Up @@ -204,9 +204,9 @@ func TestInt_Views(t *testing.T) {
return cmp.Compare(x.PolicyKind, y.PolicyKind)
})

assertPolicyReference(t, rowAccessPolicyReferences[0], aggregationPolicy, "AGGREGATION_POLICY", view.ID(), nil)
assertPolicyReference(t, rowAccessPolicyReferences[0], aggregationPolicy, sdk.PolicyKindAggregationPolicy, view.ID(), nil)

assertPolicyReference(t, rowAccessPolicyReferences[1], rowAccessPolicy.ID(), "ROW_ACCESS_POLICY", view.ID(), nil)
assertPolicyReference(t, rowAccessPolicyReferences[1], rowAccessPolicy.ID(), sdk.PolicyKindRowAccessPolicy, view.ID(), nil)
require.NotNil(t, rowAccessPolicyReferences[1].RefArgColumnNames)
refArgColumnNames := sdk.ParseCommaSeparatedStringArray(*rowAccessPolicyReferences[1].RefArgColumnNames, true)
assert.Len(t, refArgColumnNames, 1)
Expand Down Expand Up @@ -244,8 +244,8 @@ func TestInt_Views(t *testing.T) {
return cmp.Compare(x.PolicyKind, y.PolicyKind)
})

assertPolicyReference(t, rowAccessPolicyReferences[0], maskingPolicy.ID(), "MASKING_POLICY", view.ID(), sdk.Pointer("col1"))
assertPolicyReference(t, rowAccessPolicyReferences[1], projectionPolicy, "PROJECTION_POLICY", view.ID(), sdk.Pointer("col1"))
assertPolicyReference(t, rowAccessPolicyReferences[0], maskingPolicy.ID(), sdk.PolicyKindMaskingPolicy, view.ID(), sdk.Pointer("col1"))
assertPolicyReference(t, rowAccessPolicyReferences[1], projectionPolicy, sdk.PolicyKindProjectionPolicy, view.ID(), sdk.Pointer("col1"))
})

t.Run("drop view: existing", func(t *testing.T) {
Expand Down Expand Up @@ -408,16 +408,17 @@ func TestInt_Views(t *testing.T) {
require.NoError(t, err)
require.Len(t, policyReferences, 1)

assertPolicyReference(t, policyReferences[0], maskingPolicy.ID(), "MASKING_POLICY", view.ID(), sdk.Pointer("ID"))
assertPolicyReference(t, policyReferences[0], maskingPolicy.ID(), sdk.PolicyKindMaskingPolicy, view.ID(), sdk.Pointer("ID"))

alterRequest = sdk.NewAlterViewRequest(id).WithUnsetMaskingPolicyOnColumn(
*sdk.NewViewUnsetColumnMaskingPolicyRequest("ID"),
)
err = client.Views.Alter(ctx, alterRequest)
require.NoError(t, err)

_, err = testClientHelper().PolicyReferences.GetPolicyReference(t, view.ID(), sdk.PolicyEntityDomainView)
require.Error(t, err, "no rows in result set")
references, err := testClientHelper().PolicyReferences.GetPolicyReferences(t, view.ID(), sdk.PolicyEntityDomainView)
require.NoError(t, err)
require.Empty(t, references)
})

t.Run("alter view: set and unset projection policy on column", func(t *testing.T) {
Expand All @@ -437,16 +438,17 @@ func TestInt_Views(t *testing.T) {
require.NoError(t, err)
require.Len(t, rowAccessPolicyReferences, 1)

assertPolicyReference(t, rowAccessPolicyReferences[0], projectionPolicy, "PROJECTION_POLICY", view.ID(), sdk.Pointer("ID"))
assertPolicyReference(t, rowAccessPolicyReferences[0], projectionPolicy, sdk.PolicyKindProjectionPolicy, view.ID(), sdk.Pointer("ID"))

alterRequest = sdk.NewAlterViewRequest(id).WithUnsetProjectionPolicyOnColumn(
*sdk.NewViewUnsetProjectionPolicyRequest("ID"),
)
err = client.Views.Alter(ctx, alterRequest)
require.NoError(t, err)

_, err = testClientHelper().PolicyReferences.GetPolicyReference(t, view.ID(), sdk.PolicyEntityDomainView)
require.Error(t, err, "no rows in result set")
references, err := testClientHelper().PolicyReferences.GetPolicyReferences(t, view.ID(), sdk.PolicyEntityDomainView)
require.NoError(t, err)
require.Empty(t, references)
})

t.Run("alter view: set and unset tags on column", func(t *testing.T) {
Expand Down Expand Up @@ -506,15 +508,16 @@ func TestInt_Views(t *testing.T) {
rowAccessPolicyReference, err := testClientHelper().PolicyReferences.GetPolicyReference(t, view.ID(), sdk.PolicyEntityDomainView)
require.NoError(t, err)

assertPolicyReference(t, *rowAccessPolicyReference, rowAccessPolicy.ID(), "ROW_ACCESS_POLICY", view.ID(), nil)
assertPolicyReference(t, *rowAccessPolicyReference, rowAccessPolicy.ID(), sdk.PolicyKindRowAccessPolicy, view.ID(), nil)

// remove policy
alterRequest = sdk.NewAlterViewRequest(id).WithDropRowAccessPolicy(*sdk.NewViewDropRowAccessPolicyRequest(rowAccessPolicy.ID()))
err = client.Views.Alter(ctx, alterRequest)
require.NoError(t, err)

_, err = testClientHelper().PolicyReferences.GetPolicyReference(t, view.ID(), sdk.PolicyEntityDomainView)
require.Error(t, err, "no rows in result set")
references, err := testClientHelper().PolicyReferences.GetPolicyReferences(t, view.ID(), sdk.PolicyEntityDomainView)
require.NoError(t, err)
require.Empty(t, references)

// add policy again
alterRequest = sdk.NewAlterViewRequest(id).WithAddRowAccessPolicy(*sdk.NewViewAddRowAccessPolicyRequest(rowAccessPolicy.ID(), []sdk.Column{{Value: "ID"}}))
Expand Down Expand Up @@ -542,8 +545,9 @@ func TestInt_Views(t *testing.T) {
err = client.Views.Alter(ctx, alterRequest)
require.NoError(t, err)

_, err = testClientHelper().PolicyReferences.GetPolicyReference(t, view.ID(), sdk.PolicyEntityDomainView)
require.Error(t, err, "no rows in result set")
references, err = testClientHelper().PolicyReferences.GetPolicyReferences(t, view.ID(), sdk.PolicyEntityDomainView)
require.NoError(t, err)
require.Empty(t, references)
})

t.Run("alter view: add and drop data metrics", func(t *testing.T) {
Expand Down Expand Up @@ -615,8 +619,9 @@ func TestInt_Views(t *testing.T) {
err = client.Views.Alter(ctx, alterRequest)
require.NoError(t, err)

_, err = testClientHelper().PolicyReferences.GetPolicyReference(t, view.ID(), sdk.PolicyEntityDomainView)
require.Error(t, err, "no rows in result set")
references, err := testClientHelper().PolicyReferences.GetPolicyReferences(t, view.ID(), sdk.PolicyEntityDomainView)
require.NoError(t, err)
require.Empty(t, references)
})

t.Run("alter view: set and unset aggregation policies", func(t *testing.T) {
Expand All @@ -637,7 +642,7 @@ func TestInt_Views(t *testing.T) {
require.NoError(t, err)
require.Len(t, rowAccessPolicyReferences, 1)

assertPolicyReference(t, rowAccessPolicyReferences[0], aggregationPolicy, "AGGREGATION_POLICY", view.ID(), nil)
assertPolicyReference(t, rowAccessPolicyReferences[0], aggregationPolicy, sdk.PolicyKindAggregationPolicy, view.ID(), nil)

// set policy with force
alterRequest = sdk.NewAlterViewRequest(id).WithSetAggregationPolicy(*sdk.NewViewSetAggregationPolicyRequest(aggregationPolicy2).
Expand All @@ -650,15 +655,16 @@ func TestInt_Views(t *testing.T) {
require.NoError(t, err)
require.Len(t, rowAccessPolicyReferences, 1)

assertPolicyReference(t, rowAccessPolicyReferences[0], aggregationPolicy2, "AGGREGATION_POLICY", view.ID(), nil)
assertPolicyReference(t, rowAccessPolicyReferences[0], aggregationPolicy2, sdk.PolicyKindAggregationPolicy, view.ID(), nil)

// remove policy
alterRequest = sdk.NewAlterViewRequest(id).WithUnsetAggregationPolicy(*sdk.NewViewUnsetAggregationPolicyRequest())
err = client.Views.Alter(ctx, alterRequest)
require.NoError(t, err)

_, err = testClientHelper().PolicyReferences.GetPolicyReference(t, view.ID(), sdk.PolicyEntityDomainView)
require.Error(t, err, "no rows in result set")
references, err := testClientHelper().PolicyReferences.GetPolicyReferences(t, view.ID(), sdk.PolicyEntityDomainView)
require.NoError(t, err)
require.Empty(t, references)
})

t.Run("show view: default", func(t *testing.T) {
Expand Down
Loading

0 comments on commit 1887e55

Please sign in to comment.