diff --git a/pkg/acceptance/helpers/information_schema_client.go b/pkg/acceptance/helpers/information_schema_client.go index cbecf91a7a..cfcc5a1e22 100644 --- a/pkg/acceptance/helpers/information_schema_client.go +++ b/pkg/acceptance/helpers/information_schema_client.go @@ -32,3 +32,12 @@ func (c *InformationSchemaClient) GetQueryTextByQueryId(t *testing.T, queryId st require.NotNil(t, result[0]["QUERY_TEXT"]) return (*result[0]["QUERY_TEXT"]).(string) } + +func (c *InformationSchemaClient) GetQueryTagByQueryId(t *testing.T, queryId string) string { + t.Helper() + result, err := c.client().QueryUnsafe(context.Background(), fmt.Sprintf("SELECT QUERY_TAG FROM TABLE(INFORMATION_SCHEMA.QUERY_HISTORY(RESULT_LIMIT => 20)) WHERE QUERY_ID = '%s'", queryId)) + require.NoError(t, err) + require.Len(t, result, 1) + require.NotNil(t, result[0]["QUERY_TAG"]) + return (*result[0]["QUERY_TAG"]).(string) +} diff --git a/pkg/acceptance/helpers/user_client.go b/pkg/acceptance/helpers/user_client.go index c64afcf723..20461ae6e5 100644 --- a/pkg/acceptance/helpers/user_client.go +++ b/pkg/acceptance/helpers/user_client.go @@ -68,6 +68,14 @@ func (c *UserClient) Alter(t *testing.T, id sdk.AccountObjectIdentifier, opts *s require.NoError(t, err) } +func (c *UserClient) AlterCurrentUser(t *testing.T, opts *sdk.AlterUserOptions) { + t.Helper() + id, err := c.context.client.ContextFunctions.CurrentUser(context.Background()) + require.NoError(t, err) + err = c.client().Alter(context.Background(), id, opts) + require.NoError(t, err) +} + func (c *UserClient) DropUserFunc(t *testing.T, id sdk.AccountObjectIdentifier) func() { t.Helper() ctx := context.Background() diff --git a/pkg/internal/tracking/context.go b/pkg/internal/tracking/context.go index a3db0e4491..f0eb74d51c 100644 --- a/pkg/internal/tracking/context.go +++ b/pkg/internal/tracking/context.go @@ -2,15 +2,16 @@ package tracking import ( "context" + "errors" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" ) const ( - ProviderVersion string = "v0.98.0" // TODO(SNOW-1814934): Currently hardcoded, make it computed + ProviderVersion string = "v0.99.0" // TODO(SNOW-1814934): Currently hardcoded, make it computed MetadataPrefix string = "terraform_provider_usage_tracking" ) -type key int +type key struct{} var metadataContextKey key @@ -26,23 +27,37 @@ const ( ) type Metadata struct { - Version string `json:"version,omitempty"` - Resource resources.ResourceName `json:"resource,omitempty"` - Operation Operation `json:"operation,omitempty"` + Version string `json:"version,omitempty"` + Resource string `json:"resource,omitempty"` + Operation Operation `json:"operation,omitempty"` } -func NewMetadata(version string, resourceName resources.ResourceName, operation Operation) Metadata { +func (m Metadata) validate() error { + errs := make([]error, 0) + if m.Version == "" { + errs = append(errs, errors.New("version for metadata should not be empty")) + } + if m.Resource == "" { + errs = append(errs, errors.New("resource name for metadata should not be empty")) + } + if m.Operation == "" { + errs = append(errs, errors.New("operation for metadata should not be empty")) + } + return errors.Join(errs...) +} + +func NewMetadata(version string, resource resources.Resource, operation Operation) Metadata { return Metadata{ Version: version, - Resource: resourceName, + Resource: resource.String(), Operation: operation, } } -func NewVersionedMetadata(resourceName resources.ResourceName, operation Operation) Metadata { +func NewVersionedMetadata(resource resources.Resource, operation Operation) Metadata { return Metadata{ Version: ProviderVersion, - Resource: resourceName, + Resource: resource.String(), Operation: operation, } } diff --git a/pkg/internal/tracking/query.go b/pkg/internal/tracking/query.go index be7d3fe363..a3cb2a9a45 100644 --- a/pkg/internal/tracking/query.go +++ b/pkg/internal/tracking/query.go @@ -6,7 +6,7 @@ import ( "strings" ) -func AppendMetadataToSql(sql string, metadata Metadata) (string, error) { +func AppendMetadata(sql string, metadata Metadata) (string, error) { bytes, err := json.Marshal(metadata) if err != nil { return "", fmt.Errorf("failed to marshal the metadata: %w", err) @@ -15,14 +15,17 @@ func AppendMetadataToSql(sql string, metadata Metadata) (string, error) { } } -func ParseMetadataFromSql(sql string) (Metadata, error) { +func ParseMetadata(sql string) (Metadata, error) { parts := strings.Split(sql, fmt.Sprintf("--%s", MetadataPrefix)) if len(parts) != 2 { return Metadata{}, fmt.Errorf("failed to parse metadata from sql, incorrect number of parts, expected: 2, got: %d", len(parts)) } var metadata Metadata if err := json.Unmarshal([]byte(strings.TrimSpace(parts[1])), &metadata); err != nil { - return Metadata{}, fmt.Errorf("failed to unmarshal metadata from sql, err = %s", err) + return Metadata{}, fmt.Errorf("failed to unmarshal metadata from sql: %s, err = %s", sql, err) + } + if err := metadata.validate(); err != nil { + return Metadata{}, err } return metadata, nil } diff --git a/pkg/internal/tracking/query_test.go b/pkg/internal/tracking/query_test.go index 4058bb7fcd..077bb39725 100644 --- a/pkg/internal/tracking/query_test.go +++ b/pkg/internal/tracking/query_test.go @@ -8,7 +8,7 @@ import ( "testing" ) -func TestAppendMetadataToSql(t *testing.T) { +func TestAppendMetadata(t *testing.T) { metadata := NewMetadata("123", resources.Account, CreateOperation) sql := "SELECT 1" @@ -17,33 +17,36 @@ func TestAppendMetadataToSql(t *testing.T) { expectedSql := fmt.Sprintf("%s --%s %s", sql, MetadataPrefix, string(bytes)) - newSql, err := AppendMetadataToSql(sql, metadata) + newSql, err := AppendMetadata(sql, metadata) require.NoError(t, err) require.Equal(t, expectedSql, newSql) } -func TestParseMetadataFromSql(t *testing.T) { +func TestParseMetadata(t *testing.T) { metadata := NewMetadata("123", resources.Account, CreateOperation) - sql, err := AppendMetadataToSql("SELECT 1", metadata) + bytes, err := json.Marshal(metadata) require.NoError(t, err) + sql := fmt.Sprintf("SELECT 1 --%s %s", MetadataPrefix, string(bytes)) - parsedMetadata, err := ParseMetadataFromSql(sql) + parsedMetadata, err := ParseMetadata(sql) require.NoError(t, err) require.Equal(t, metadata, parsedMetadata) } -func TestParseInvalidMetadataKeysFromSql(t *testing.T) { +func TestParseInvalidMetadataKeys(t *testing.T) { sql := fmt.Sprintf(`SELECT 1 --%s {"key": "value"}`, MetadataPrefix) - parsedMetadata, err := ParseMetadataFromSql(sql) - require.NoError(t, err) + parsedMetadata, err := ParseMetadata(sql) + require.ErrorContains(t, err, "version for metadata should not be empty") + require.ErrorContains(t, err, "resource name for metadata should not be empty") + require.ErrorContains(t, err, "operation for metadata should not be empty") require.Equal(t, Metadata{}, parsedMetadata) } -func TestParseInvalidMetadataJsonFromSql(t *testing.T) { +func TestParseInvalidMetadataJson(t *testing.T) { sql := fmt.Sprintf(`SELECT 1 --%s "key": "value"`, MetadataPrefix) - parsedMetadata, err := ParseMetadataFromSql(sql) + parsedMetadata, err := ParseMetadata(sql) require.ErrorContains(t, err, "failed to unmarshal metadata from sql") require.Equal(t, Metadata{}, parsedMetadata) } @@ -55,7 +58,7 @@ func TestParseMetadataFromInvalidSqlCommentPrefix(t *testing.T) { bytes, err := json.Marshal(metadata) require.NoError(t, err) - parsedMetadata, err := ParseMetadataFromSql(fmt.Sprintf("%s --invalid_prefix %s", sql, string(bytes))) + parsedMetadata, err := ParseMetadata(fmt.Sprintf("%s --invalid_prefix %s", sql, string(bytes))) require.ErrorContains(t, err, "failed to parse metadata from sql") require.Equal(t, Metadata{}, parsedMetadata) } diff --git a/pkg/resources/common.go b/pkg/resources/common.go index d0cfd8b730..b33a0f433d 100644 --- a/pkg/resources/common.go +++ b/pkg/resources/common.go @@ -105,42 +105,42 @@ func ImportName[T sdk.AccountObjectIdentifier | sdk.DatabaseObjectIdentifier | s return []*schema.ResourceData{d}, nil } -func CommonImportWrapper(resourceName resources.ResourceName, importImplementation schema.StateContextFunc) schema.StateContextFunc { +func TrackingImportWrapper(resourceName resources.Resource, importImplementation schema.StateContextFunc) schema.StateContextFunc { return func(ctx context.Context, d *schema.ResourceData, meta any) ([]*schema.ResourceData, error) { ctx = tracking.NewContext(ctx, tracking.NewVersionedMetadata(resourceName, tracking.ImportOperation)) return importImplementation(ctx, d, meta) } } -func CommonCreateWrapper(resourceName resources.ResourceName, createImplementation schema.CreateContextFunc) schema.CreateContextFunc { +func TrackingCreateWrapper(resourceName resources.Resource, createImplementation schema.CreateContextFunc) schema.CreateContextFunc { return func(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { ctx = tracking.NewContext(ctx, tracking.NewVersionedMetadata(resourceName, tracking.CreateOperation)) return createImplementation(ctx, d, meta) } } -func CommonReadWrapper(resourceName resources.ResourceName, readImplementation schema.ReadContextFunc) schema.ReadContextFunc { +func TrackingReadWrapper(resourceName resources.Resource, readImplementation schema.ReadContextFunc) schema.ReadContextFunc { return func(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { ctx = tracking.NewContext(ctx, tracking.NewVersionedMetadata(resourceName, tracking.ReadOperation)) return readImplementation(ctx, d, meta) } } -func CommonUpdateWrapper(resourceName resources.ResourceName, updateImplementation schema.UpdateContextFunc) schema.UpdateContextFunc { +func TrackingUpdateWrapper(resourceName resources.Resource, updateImplementation schema.UpdateContextFunc) schema.UpdateContextFunc { return func(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { ctx = tracking.NewContext(ctx, tracking.NewVersionedMetadata(resourceName, tracking.UpdateOperation)) return updateImplementation(ctx, d, meta) } } -func CommonDeleteWrapper(resourceName resources.ResourceName, deleteImplementation schema.DeleteContextFunc) schema.DeleteContextFunc { +func TrackingDeleteWrapper(resourceName resources.Resource, deleteImplementation schema.DeleteContextFunc) schema.DeleteContextFunc { return func(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { ctx = tracking.NewContext(ctx, tracking.NewVersionedMetadata(resourceName, tracking.DeleteOperation)) return deleteImplementation(ctx, d, meta) } } -func CommonCustomDiffWrapper(resourceName resources.ResourceName, customdiffImplementation schema.CustomizeDiffFunc) schema.CustomizeDiffFunc { +func TrackingCustomDiffWrapper(resourceName resources.Resource, customdiffImplementation schema.CustomizeDiffFunc) schema.CustomizeDiffFunc { return func(ctx context.Context, diff *schema.ResourceDiff, meta any) error { ctx = tracking.NewContext(ctx, tracking.NewVersionedMetadata(resourceName, tracking.CustomDiffOperation)) return customdiffImplementation(ctx, diff, meta) diff --git a/pkg/resources/schema.go b/pkg/resources/schema.go index f519d0534e..f485428333 100644 --- a/pkg/resources/schema.go +++ b/pkg/resources/schema.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/tracking" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" "log" "slices" @@ -91,13 +90,13 @@ var schemaSchema = map[string]*schema.Schema{ // Schema returns a pointer to the resource representing a schema. func Schema() *schema.Resource { return &schema.Resource{ - CreateContext: CommonCreateWrapper(resources.Schema, CreateContextSchema), - ReadContext: CommonReadWrapper(resources.Schema, ReadContextSchema(true)), - UpdateContext: CommonUpdateWrapper(resources.Schema, UpdateContextSchema), - DeleteContext: CommonDeleteWrapper(resources.Schema, DeleteContextSchema), + CreateContext: TrackingCreateWrapper(resources.Schema, CreateContextSchema), + ReadContext: TrackingReadWrapper(resources.Schema, ReadContextSchema(true)), + UpdateContext: TrackingUpdateWrapper(resources.Schema, UpdateContextSchema), + DeleteContext: TrackingDeleteWrapper(resources.Schema, DeleteContextSchema), Description: "Resource used to manage schema objects. For more information, check [schema documentation](https://docs.snowflake.com/en/sql-reference/sql/create-schema).", - CustomizeDiff: CommonCustomDiffWrapper(resources.Schema, customdiff.All( + CustomizeDiff: TrackingCustomDiffWrapper(resources.Schema, customdiff.All( ComputedIfAnyAttributeChanged(schemaSchema, ShowOutputAttributeName, "name", "comment", "with_managed_access", "is_transient"), ComputedIfAnyAttributeChanged(schemaSchema, DescribeOutputAttributeName, "name"), ComputedIfAnyAttributeChanged(schemaSchema, FullyQualifiedNameAttributeName, "name"), @@ -108,7 +107,7 @@ func Schema() *schema.Resource { Schema: collections.MergeMaps(schemaSchema, schemaParametersSchema), Importer: &schema.ResourceImporter{ - StateContext: CommonImportWrapper(resources.Schema, ImportSchema), + StateContext: TrackingImportWrapper(resources.Schema, ImportSchema), }, SchemaVersion: 2, @@ -131,7 +130,6 @@ func Schema() *schema.Resource { func ImportSchema(ctx context.Context, d *schema.ResourceData, meta any) ([]*schema.ResourceData, error) { log.Printf("[DEBUG] Starting schema import") - ctx = tracking.NewContext(ctx, tracking.NewVersionedMetadata(resources.Schema, tracking.ImportOperation)) client := meta.(*provider.Context).Client id, err := sdk.ParseDatabaseObjectIdentifier(d.Id()) if err != nil { diff --git a/pkg/sdk/client.go b/pkg/sdk/client.go index 6aeab50663..b7796b9ffe 100644 --- a/pkg/sdk/client.go +++ b/pkg/sdk/client.go @@ -308,7 +308,7 @@ func (c *Client) queryOne(ctx context.Context, dest interface{}, sql string) err func appendQueryMetadata(ctx context.Context, sql string) string { if metadata, ok := tracking.FromContext(ctx); ok { - newSql, err := tracking.AppendMetadataToSql(sql, metadata) + newSql, err := tracking.AppendMetadata(sql, metadata) if err != nil { log.Printf("[ERROR] failed to append metadata tracking: %v\n", err) return sql diff --git a/pkg/sdk/testint/basic_object_tracking_integration_test.go b/pkg/sdk/testint/basic_object_tracking_integration_test.go index da3d1e2435..9e70f0813d 100644 --- a/pkg/sdk/testint/basic_object_tracking_integration_test.go +++ b/pkg/sdk/testint/basic_object_tracking_integration_test.go @@ -17,39 +17,56 @@ func TestInt_ContextQueryTags(t *testing.T) { client := testClient(t) ctx := context.Background() - sessionId, err := client.ContextFunctions.CurrentSession(ctx) - require.NoError(t, err) + // set query_tag on user level + userQueryTag := "user query tag" + testClientHelper().User.AlterCurrentUser(t, &sdk.AlterUserOptions{ + Set: &sdk.UserSet{ + SessionParameters: &sdk.SessionParameters{ + QueryTag: sdk.String(userQueryTag), + }, + }, + }) + t.Cleanup(func() { + testClientHelper().User.AlterCurrentUser(t, &sdk.AlterUserOptions{ + Unset: &sdk.UserUnset{ + SessionParameters: &sdk.SessionParametersUnset{ + QueryTag: sdk.Bool(true), + }, + }, + }) + }) + queryId := executeQueryAndReturnQueryId(t, context.Background(), client) + queryTagResult := testClientHelper().InformationSchema.GetQueryTagByQueryId(t, queryId) + require.Equal(t, userQueryTag, queryTagResult) - queryTag := "some query tag" + // set query_tag on session level + sessionQueryTag := "session query tag" require.NoError(t, client.Sessions.AlterSession(ctx, &sdk.AlterSessionOptions{ Set: &sdk.SessionSet{ SessionParameters: &sdk.SessionParameters{ - QueryTag: sdk.String(queryTag), + QueryTag: sdk.String(sessionQueryTag), }, }, })) t.Cleanup(func() { - _, err = client.QueryUnsafe(ctx, "ALTER SESSION UNSET QUERY_TAG") - require.NoError(t, err) + require.NoError(t, client.Sessions.AlterSession(ctx, &sdk.AlterSessionOptions{ + Unset: &sdk.SessionUnset{ + SessionParametersUnset: &sdk.SessionParametersUnset{ + QueryTag: sdk.Bool(true), + }, + }, + })) }) - - queryId := executeQueryAndReturnQueryId(t, context.Background(), client) - - result, err := client.QueryUnsafe(ctx, fmt.Sprintf("SELECT QUERY_ID, QUERY_TAG FROM TABLE(INFORMATION_SCHEMA.QUERY_HISTORY_BY_SESSION(SESSION_ID => %s, RESULT_LIMIT => 2)) WHERE QUERY_ID = '%s'", sessionId, queryId)) - require.NoError(t, err) - require.Len(t, result, 1) - require.Equal(t, queryId, *result[0]["QUERY_ID"]) - require.Equal(t, queryTag, *result[0]["QUERY_TAG"]) - - newQueryTag := "some other query tag" - ctxWithQueryTag := gosnowflake.WithQueryTag(context.Background(), newQueryTag) - newQueryId := executeQueryAndReturnQueryId(t, ctxWithQueryTag, client) - - result, err = client.QueryUnsafe(ctx, fmt.Sprintf("SELECT QUERY_ID, QUERY_TAG FROM TABLE(INFORMATION_SCHEMA.QUERY_HISTORY_BY_SESSION(SESSION_ID => %s, RESULT_LIMIT => 2)) WHERE QUERY_ID = '%s'", sessionId, newQueryId)) - require.NoError(t, err) - require.Len(t, result, 1) - require.Equal(t, newQueryId, *result[0]["QUERY_ID"]) - require.Equal(t, newQueryTag, *result[0]["QUERY_TAG"]) + queryId = executeQueryAndReturnQueryId(t, context.Background(), client) + queryTagResult = testClientHelper().InformationSchema.GetQueryTagByQueryId(t, queryId) + require.Equal(t, sessionQueryTag, queryTagResult) + + // set query_tag on query level + perQueryQueryTag := "per-query query tag" + ctxWithQueryTag := gosnowflake.WithQueryTag(context.Background(), perQueryQueryTag) + queryId = executeQueryAndReturnQueryId(t, ctxWithQueryTag, client) + queryTagResult = testClientHelper().InformationSchema.GetQueryTagByQueryId(t, queryId) + require.Equal(t, perQueryQueryTag, queryTagResult) } func executeQueryAndReturnQueryId(t *testing.T, ctx context.Context, client *sdk.Client) string { @@ -67,20 +84,14 @@ func TestInt_QueryComment(t *testing.T) { client := testClient(t) ctx := context.Background() - sessionId, err := client.ContextFunctions.CurrentSession(ctx) - require.NoError(t, err) - queryIdChan := make(chan string, 1) metadata := `{"comment": "some comment"}` - _, err = client.QueryUnsafe(gosnowflake.WithQueryIDChan(ctx, queryIdChan), fmt.Sprintf(`SELECT 1; --%s`, metadata)) + _, err := client.QueryUnsafe(gosnowflake.WithQueryIDChan(ctx, queryIdChan), fmt.Sprintf(`SELECT 1; --%s`, metadata)) require.NoError(t, err) queryId := <-queryIdChan - result, err := client.QueryUnsafe(ctx, fmt.Sprintf("SELECT QUERY_ID, QUERY_TEXT FROM TABLE(INFORMATION_SCHEMA.QUERY_HISTORY_BY_SESSION(SESSION_ID => %s, RESULT_LIMIT => 2)) WHERE QUERY_ID = '%s'", sessionId, queryId)) - require.NoError(t, err) - require.Len(t, result, 1) - require.Equal(t, queryId, *result[0]["QUERY_ID"]) - require.Equal(t, metadata, strings.Split((*result[0]["QUERY_TEXT"]).(string), "--")[1]) + queryText := testClientHelper().InformationSchema.GetQueryTextByQueryId(t, queryId) + require.Equal(t, metadata, strings.Split(queryText, "--")[1]) } func TestInt_AppName(t *testing.T) { diff --git a/pkg/sdk/testint/client_integration_test.go b/pkg/sdk/testint/client_integration_test.go index 8e22e58e2d..64fd237ed2 100644 --- a/pkg/sdk/testint/client_integration_test.go +++ b/pkg/sdk/testint/client_integration_test.go @@ -3,6 +3,7 @@ package testint import ( "context" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/tracking" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" "github.com/snowflakedb/gosnowflake" "github.com/stretchr/testify/require" "testing" @@ -10,17 +11,12 @@ import ( func TestInt_Client_AdditionalMetadata(t *testing.T) { client := testClient(t) - - metadata := tracking.Metadata{ - Version: "v1.0.0", - Resource: "database", - Operation: tracking.CreateOperation, - } + metadata := tracking.NewMetadata("v1.13.1002-rc-test", resources.Database, tracking.CreateOperation) assertQueryMetadata := func(t *testing.T, queryId string) { t.Helper() queryText := testClientHelper().InformationSchema.GetQueryTextByQueryId(t, queryId) - parsedMetadata, err := tracking.ParseMetadataFromSql(queryText) + parsedMetadata, err := tracking.ParseMetadata(queryText) require.NoError(t, err) require.Equal(t, metadata, parsedMetadata) }