Skip to content

Commit

Permalink
changes after review
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jcieslak committed Nov 20, 2024
1 parent 5000619 commit b53359f
Show file tree
Hide file tree
Showing 10 changed files with 121 additions and 78 deletions.
9 changes: 9 additions & 0 deletions pkg/acceptance/helpers/information_schema_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,12 @@ func (c *InformationSchemaClient) GetQueryTextByQueryId(t *testing.T, queryId st
require.NotNil(t, result[0]["QUERY_TEXT"])
return (*result[0]["QUERY_TEXT"]).(string)
}

func (c *InformationSchemaClient) GetQueryTagByQueryId(t *testing.T, queryId string) string {
t.Helper()
result, err := c.client().QueryUnsafe(context.Background(), fmt.Sprintf("SELECT QUERY_TAG FROM TABLE(INFORMATION_SCHEMA.QUERY_HISTORY(RESULT_LIMIT => 20)) WHERE QUERY_ID = '%s'", queryId))
require.NoError(t, err)
require.Len(t, result, 1)
require.NotNil(t, result[0]["QUERY_TAG"])
return (*result[0]["QUERY_TAG"]).(string)
}
8 changes: 8 additions & 0 deletions pkg/acceptance/helpers/user_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ func (c *UserClient) Alter(t *testing.T, id sdk.AccountObjectIdentifier, opts *s
require.NoError(t, err)
}

func (c *UserClient) AlterCurrentUser(t *testing.T, opts *sdk.AlterUserOptions) {
t.Helper()
id, err := c.context.client.ContextFunctions.CurrentUser(context.Background())
require.NoError(t, err)
err = c.client().Alter(context.Background(), id, opts)
require.NoError(t, err)
}

func (c *UserClient) DropUserFunc(t *testing.T, id sdk.AccountObjectIdentifier) func() {
t.Helper()
ctx := context.Background()
Expand Down
33 changes: 24 additions & 9 deletions pkg/internal/tracking/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@ package tracking

import (
"context"
"errors"

Check failure on line 5 in pkg/internal/tracking/context.go

View workflow job for this annotation

GitHub Actions / reviewdog

[golangci] reported by reviewdog 🐶 File is not `gofumpt`-ed (gofumpt) Raw Output: pkg/internal/tracking/context.go:5: File is not `gofumpt`-ed (gofumpt) "errors"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources"
)

const (
ProviderVersion string = "v0.98.0" // TODO(SNOW-1814934): Currently hardcoded, make it computed
ProviderVersion string = "v0.99.0" // TODO(SNOW-1814934): Currently hardcoded, make it computed
MetadataPrefix string = "terraform_provider_usage_tracking"
)

type key int
type key struct{}

var metadataContextKey key

Expand All @@ -26,23 +27,37 @@ const (
)

type Metadata struct {
Version string `json:"version,omitempty"`
Resource resources.ResourceName `json:"resource,omitempty"`
Operation Operation `json:"operation,omitempty"`
Version string `json:"version,omitempty"`
Resource string `json:"resource,omitempty"`
Operation Operation `json:"operation,omitempty"`
}

func NewMetadata(version string, resourceName resources.ResourceName, operation Operation) Metadata {
func (m Metadata) validate() error {
errs := make([]error, 0)
if m.Version == "" {
errs = append(errs, errors.New("version for metadata should not be empty"))
}
if m.Resource == "" {
errs = append(errs, errors.New("resource name for metadata should not be empty"))
}
if m.Operation == "" {
errs = append(errs, errors.New("operation for metadata should not be empty"))
}
return errors.Join(errs...)
}

func NewMetadata(version string, resource resources.Resource, operation Operation) Metadata {
return Metadata{
Version: version,
Resource: resourceName,
Resource: resource.String(),
Operation: operation,
}
}

func NewVersionedMetadata(resourceName resources.ResourceName, operation Operation) Metadata {
func NewVersionedMetadata(resource resources.Resource, operation Operation) Metadata {
return Metadata{
Version: ProviderVersion,
Resource: resourceName,
Resource: resource.String(),
Operation: operation,
}
}
Expand Down
9 changes: 6 additions & 3 deletions pkg/internal/tracking/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"strings"
)

func AppendMetadataToSql(sql string, metadata Metadata) (string, error) {
func AppendMetadata(sql string, metadata Metadata) (string, error) {
bytes, err := json.Marshal(metadata)
if err != nil {
return "", fmt.Errorf("failed to marshal the metadata: %w", err)
Expand All @@ -15,14 +15,17 @@ func AppendMetadataToSql(sql string, metadata Metadata) (string, error) {
}
}

func ParseMetadataFromSql(sql string) (Metadata, error) {
func ParseMetadata(sql string) (Metadata, error) {
parts := strings.Split(sql, fmt.Sprintf("--%s", MetadataPrefix))
if len(parts) != 2 {
return Metadata{}, fmt.Errorf("failed to parse metadata from sql, incorrect number of parts, expected: 2, got: %d", len(parts))
}
var metadata Metadata
if err := json.Unmarshal([]byte(strings.TrimSpace(parts[1])), &metadata); err != nil {
return Metadata{}, fmt.Errorf("failed to unmarshal metadata from sql, err = %s", err)
return Metadata{}, fmt.Errorf("failed to unmarshal metadata from sql: %s, err = %s", sql, err)

Check failure on line 25 in pkg/internal/tracking/query.go

View workflow job for this annotation

GitHub Actions / reviewdog

[golangci] reported by reviewdog 🐶 non-wrapping format verb for fmt.Errorf. Use `%w` to format errors (errorlint) Raw Output: pkg/internal/tracking/query.go:25:93: non-wrapping format verb for fmt.Errorf. Use `%w` to format errors (errorlint) return Metadata{}, fmt.Errorf("failed to unmarshal metadata from sql: %s, err = %s", sql, err) ^
}
if err := metadata.validate(); err != nil {
return Metadata{}, err
}
return metadata, nil
}
25 changes: 14 additions & 11 deletions pkg/internal/tracking/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"testing"
)

func TestAppendMetadataToSql(t *testing.T) {
func TestAppendMetadata(t *testing.T) {
metadata := NewMetadata("123", resources.Account, CreateOperation)
sql := "SELECT 1"

Expand All @@ -17,33 +17,36 @@ func TestAppendMetadataToSql(t *testing.T) {

expectedSql := fmt.Sprintf("%s --%s %s", sql, MetadataPrefix, string(bytes))

newSql, err := AppendMetadataToSql(sql, metadata)
newSql, err := AppendMetadata(sql, metadata)
require.NoError(t, err)
require.Equal(t, expectedSql, newSql)
}

func TestParseMetadataFromSql(t *testing.T) {
func TestParseMetadata(t *testing.T) {
metadata := NewMetadata("123", resources.Account, CreateOperation)
sql, err := AppendMetadataToSql("SELECT 1", metadata)
bytes, err := json.Marshal(metadata)
require.NoError(t, err)
sql := fmt.Sprintf("SELECT 1 --%s %s", MetadataPrefix, string(bytes))

parsedMetadata, err := ParseMetadataFromSql(sql)
parsedMetadata, err := ParseMetadata(sql)
require.NoError(t, err)
require.Equal(t, metadata, parsedMetadata)
}

func TestParseInvalidMetadataKeysFromSql(t *testing.T) {
func TestParseInvalidMetadataKeys(t *testing.T) {
sql := fmt.Sprintf(`SELECT 1 --%s {"key": "value"}`, MetadataPrefix)

parsedMetadata, err := ParseMetadataFromSql(sql)
require.NoError(t, err)
parsedMetadata, err := ParseMetadata(sql)
require.ErrorContains(t, err, "version for metadata should not be empty")
require.ErrorContains(t, err, "resource name for metadata should not be empty")
require.ErrorContains(t, err, "operation for metadata should not be empty")
require.Equal(t, Metadata{}, parsedMetadata)
}

func TestParseInvalidMetadataJsonFromSql(t *testing.T) {
func TestParseInvalidMetadataJson(t *testing.T) {
sql := fmt.Sprintf(`SELECT 1 --%s "key": "value"`, MetadataPrefix)

parsedMetadata, err := ParseMetadataFromSql(sql)
parsedMetadata, err := ParseMetadata(sql)
require.ErrorContains(t, err, "failed to unmarshal metadata from sql")
require.Equal(t, Metadata{}, parsedMetadata)
}
Expand All @@ -55,7 +58,7 @@ func TestParseMetadataFromInvalidSqlCommentPrefix(t *testing.T) {
bytes, err := json.Marshal(metadata)
require.NoError(t, err)

parsedMetadata, err := ParseMetadataFromSql(fmt.Sprintf("%s --invalid_prefix %s", sql, string(bytes)))
parsedMetadata, err := ParseMetadata(fmt.Sprintf("%s --invalid_prefix %s", sql, string(bytes)))
require.ErrorContains(t, err, "failed to parse metadata from sql")
require.Equal(t, Metadata{}, parsedMetadata)
}
12 changes: 6 additions & 6 deletions pkg/resources/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,42 +105,42 @@ func ImportName[T sdk.AccountObjectIdentifier | sdk.DatabaseObjectIdentifier | s
return []*schema.ResourceData{d}, nil
}

func CommonImportWrapper(resourceName resources.ResourceName, importImplementation schema.StateContextFunc) schema.StateContextFunc {
func TrackingImportWrapper(resourceName resources.Resource, importImplementation schema.StateContextFunc) schema.StateContextFunc {
return func(ctx context.Context, d *schema.ResourceData, meta any) ([]*schema.ResourceData, error) {
ctx = tracking.NewContext(ctx, tracking.NewVersionedMetadata(resourceName, tracking.ImportOperation))
return importImplementation(ctx, d, meta)
}
}

func CommonCreateWrapper(resourceName resources.ResourceName, createImplementation schema.CreateContextFunc) schema.CreateContextFunc {
func TrackingCreateWrapper(resourceName resources.Resource, createImplementation schema.CreateContextFunc) schema.CreateContextFunc {
return func(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics {
ctx = tracking.NewContext(ctx, tracking.NewVersionedMetadata(resourceName, tracking.CreateOperation))
return createImplementation(ctx, d, meta)
}
}

func CommonReadWrapper(resourceName resources.ResourceName, readImplementation schema.ReadContextFunc) schema.ReadContextFunc {
func TrackingReadWrapper(resourceName resources.Resource, readImplementation schema.ReadContextFunc) schema.ReadContextFunc {
return func(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics {
ctx = tracking.NewContext(ctx, tracking.NewVersionedMetadata(resourceName, tracking.ReadOperation))
return readImplementation(ctx, d, meta)
}
}

func CommonUpdateWrapper(resourceName resources.ResourceName, updateImplementation schema.UpdateContextFunc) schema.UpdateContextFunc {
func TrackingUpdateWrapper(resourceName resources.Resource, updateImplementation schema.UpdateContextFunc) schema.UpdateContextFunc {
return func(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics {
ctx = tracking.NewContext(ctx, tracking.NewVersionedMetadata(resourceName, tracking.UpdateOperation))
return updateImplementation(ctx, d, meta)
}
}

func CommonDeleteWrapper(resourceName resources.ResourceName, deleteImplementation schema.DeleteContextFunc) schema.DeleteContextFunc {
func TrackingDeleteWrapper(resourceName resources.Resource, deleteImplementation schema.DeleteContextFunc) schema.DeleteContextFunc {
return func(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics {
ctx = tracking.NewContext(ctx, tracking.NewVersionedMetadata(resourceName, tracking.DeleteOperation))
return deleteImplementation(ctx, d, meta)
}
}

func CommonCustomDiffWrapper(resourceName resources.ResourceName, customdiffImplementation schema.CustomizeDiffFunc) schema.CustomizeDiffFunc {
func TrackingCustomDiffWrapper(resourceName resources.Resource, customdiffImplementation schema.CustomizeDiffFunc) schema.CustomizeDiffFunc {
return func(ctx context.Context, diff *schema.ResourceDiff, meta any) error {
ctx = tracking.NewContext(ctx, tracking.NewVersionedMetadata(resourceName, tracking.CustomDiffOperation))
return customdiffImplementation(ctx, diff, meta)
Expand Down
14 changes: 6 additions & 8 deletions pkg/resources/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/tracking"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources"
"log"
"slices"
Expand Down Expand Up @@ -91,13 +90,13 @@ var schemaSchema = map[string]*schema.Schema{
// Schema returns a pointer to the resource representing a schema.
func Schema() *schema.Resource {
return &schema.Resource{
CreateContext: CommonCreateWrapper(resources.Schema, CreateContextSchema),
ReadContext: CommonReadWrapper(resources.Schema, ReadContextSchema(true)),
UpdateContext: CommonUpdateWrapper(resources.Schema, UpdateContextSchema),
DeleteContext: CommonDeleteWrapper(resources.Schema, DeleteContextSchema),
CreateContext: TrackingCreateWrapper(resources.Schema, CreateContextSchema),
ReadContext: TrackingReadWrapper(resources.Schema, ReadContextSchema(true)),
UpdateContext: TrackingUpdateWrapper(resources.Schema, UpdateContextSchema),
DeleteContext: TrackingDeleteWrapper(resources.Schema, DeleteContextSchema),
Description: "Resource used to manage schema objects. For more information, check [schema documentation](https://docs.snowflake.com/en/sql-reference/sql/create-schema).",

CustomizeDiff: CommonCustomDiffWrapper(resources.Schema, customdiff.All(
CustomizeDiff: TrackingCustomDiffWrapper(resources.Schema, customdiff.All(
ComputedIfAnyAttributeChanged(schemaSchema, ShowOutputAttributeName, "name", "comment", "with_managed_access", "is_transient"),
ComputedIfAnyAttributeChanged(schemaSchema, DescribeOutputAttributeName, "name"),
ComputedIfAnyAttributeChanged(schemaSchema, FullyQualifiedNameAttributeName, "name"),
Expand All @@ -108,7 +107,7 @@ func Schema() *schema.Resource {

Schema: collections.MergeMaps(schemaSchema, schemaParametersSchema),
Importer: &schema.ResourceImporter{
StateContext: CommonImportWrapper(resources.Schema, ImportSchema),
StateContext: TrackingImportWrapper(resources.Schema, ImportSchema),
},

SchemaVersion: 2,
Expand All @@ -131,7 +130,6 @@ func Schema() *schema.Resource {

func ImportSchema(ctx context.Context, d *schema.ResourceData, meta any) ([]*schema.ResourceData, error) {
log.Printf("[DEBUG] Starting schema import")
ctx = tracking.NewContext(ctx, tracking.NewVersionedMetadata(resources.Schema, tracking.ImportOperation))
client := meta.(*provider.Context).Client
id, err := sdk.ParseDatabaseObjectIdentifier(d.Id())
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion pkg/sdk/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ func (c *Client) queryOne(ctx context.Context, dest interface{}, sql string) err

func appendQueryMetadata(ctx context.Context, sql string) string {
if metadata, ok := tracking.FromContext(ctx); ok {
newSql, err := tracking.AppendMetadataToSql(sql, metadata)
newSql, err := tracking.AppendMetadata(sql, metadata)
if err != nil {
log.Printf("[ERROR] failed to append metadata tracking: %v\n", err)
return sql
Expand Down
77 changes: 44 additions & 33 deletions pkg/sdk/testint/basic_object_tracking_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,39 +17,56 @@ func TestInt_ContextQueryTags(t *testing.T) {
client := testClient(t)
ctx := context.Background()

sessionId, err := client.ContextFunctions.CurrentSession(ctx)
require.NoError(t, err)
// set query_tag on user level
userQueryTag := "user query tag"
testClientHelper().User.AlterCurrentUser(t, &sdk.AlterUserOptions{
Set: &sdk.UserSet{
SessionParameters: &sdk.SessionParameters{
QueryTag: sdk.String(userQueryTag),
},
},
})
t.Cleanup(func() {
testClientHelper().User.AlterCurrentUser(t, &sdk.AlterUserOptions{
Unset: &sdk.UserUnset{
SessionParameters: &sdk.SessionParametersUnset{
QueryTag: sdk.Bool(true),
},
},
})
})
queryId := executeQueryAndReturnQueryId(t, context.Background(), client)
queryTagResult := testClientHelper().InformationSchema.GetQueryTagByQueryId(t, queryId)
require.Equal(t, userQueryTag, queryTagResult)

queryTag := "some query tag"
// set query_tag on session level
sessionQueryTag := "session query tag"
require.NoError(t, client.Sessions.AlterSession(ctx, &sdk.AlterSessionOptions{
Set: &sdk.SessionSet{
SessionParameters: &sdk.SessionParameters{
QueryTag: sdk.String(queryTag),
QueryTag: sdk.String(sessionQueryTag),
},
},
}))
t.Cleanup(func() {
_, err = client.QueryUnsafe(ctx, "ALTER SESSION UNSET QUERY_TAG")
require.NoError(t, err)
require.NoError(t, client.Sessions.AlterSession(ctx, &sdk.AlterSessionOptions{
Unset: &sdk.SessionUnset{
SessionParametersUnset: &sdk.SessionParametersUnset{
QueryTag: sdk.Bool(true),
},
},
}))
})

queryId := executeQueryAndReturnQueryId(t, context.Background(), client)

result, err := client.QueryUnsafe(ctx, fmt.Sprintf("SELECT QUERY_ID, QUERY_TAG FROM TABLE(INFORMATION_SCHEMA.QUERY_HISTORY_BY_SESSION(SESSION_ID => %s, RESULT_LIMIT => 2)) WHERE QUERY_ID = '%s'", sessionId, queryId))
require.NoError(t, err)
require.Len(t, result, 1)
require.Equal(t, queryId, *result[0]["QUERY_ID"])
require.Equal(t, queryTag, *result[0]["QUERY_TAG"])

newQueryTag := "some other query tag"
ctxWithQueryTag := gosnowflake.WithQueryTag(context.Background(), newQueryTag)
newQueryId := executeQueryAndReturnQueryId(t, ctxWithQueryTag, client)

result, err = client.QueryUnsafe(ctx, fmt.Sprintf("SELECT QUERY_ID, QUERY_TAG FROM TABLE(INFORMATION_SCHEMA.QUERY_HISTORY_BY_SESSION(SESSION_ID => %s, RESULT_LIMIT => 2)) WHERE QUERY_ID = '%s'", sessionId, newQueryId))
require.NoError(t, err)
require.Len(t, result, 1)
require.Equal(t, newQueryId, *result[0]["QUERY_ID"])
require.Equal(t, newQueryTag, *result[0]["QUERY_TAG"])
queryId = executeQueryAndReturnQueryId(t, context.Background(), client)
queryTagResult = testClientHelper().InformationSchema.GetQueryTagByQueryId(t, queryId)
require.Equal(t, sessionQueryTag, queryTagResult)

// set query_tag on query level
perQueryQueryTag := "per-query query tag"
ctxWithQueryTag := gosnowflake.WithQueryTag(context.Background(), perQueryQueryTag)
queryId = executeQueryAndReturnQueryId(t, ctxWithQueryTag, client)
queryTagResult = testClientHelper().InformationSchema.GetQueryTagByQueryId(t, queryId)
require.Equal(t, perQueryQueryTag, queryTagResult)
}

func executeQueryAndReturnQueryId(t *testing.T, ctx context.Context, client *sdk.Client) string {

Check failure on line 72 in pkg/sdk/testint/basic_object_tracking_integration_test.go

View workflow job for this annotation

GitHub Actions / reviewdog

[golangci] reported by reviewdog 🐶 test helper function should start from t.Helper() (thelper) Raw Output: pkg/sdk/testint/basic_object_tracking_integration_test.go:72:6: test helper function should start from t.Helper() (thelper) func executeQueryAndReturnQueryId(t *testing.T, ctx context.Context, client *sdk.Client) string { ^
Expand All @@ -67,20 +84,14 @@ func TestInt_QueryComment(t *testing.T) {
client := testClient(t)
ctx := context.Background()

sessionId, err := client.ContextFunctions.CurrentSession(ctx)
require.NoError(t, err)

queryIdChan := make(chan string, 1)
metadata := `{"comment": "some comment"}`
_, err = client.QueryUnsafe(gosnowflake.WithQueryIDChan(ctx, queryIdChan), fmt.Sprintf(`SELECT 1; --%s`, metadata))
_, err := client.QueryUnsafe(gosnowflake.WithQueryIDChan(ctx, queryIdChan), fmt.Sprintf(`SELECT 1; --%s`, metadata))
require.NoError(t, err)
queryId := <-queryIdChan

result, err := client.QueryUnsafe(ctx, fmt.Sprintf("SELECT QUERY_ID, QUERY_TEXT FROM TABLE(INFORMATION_SCHEMA.QUERY_HISTORY_BY_SESSION(SESSION_ID => %s, RESULT_LIMIT => 2)) WHERE QUERY_ID = '%s'", sessionId, queryId))
require.NoError(t, err)
require.Len(t, result, 1)
require.Equal(t, queryId, *result[0]["QUERY_ID"])
require.Equal(t, metadata, strings.Split((*result[0]["QUERY_TEXT"]).(string), "--")[1])
queryText := testClientHelper().InformationSchema.GetQueryTextByQueryId(t, queryId)
require.Equal(t, metadata, strings.Split(queryText, "--")[1])
}

func TestInt_AppName(t *testing.T) {
Expand Down
Loading

0 comments on commit b53359f

Please sign in to comment.