diff --git a/pkg/acceptance/check_destroy.go b/pkg/acceptance/check_destroy.go index 0ab27df968..617c90e5d2 100644 --- a/pkg/acceptance/check_destroy.go +++ b/pkg/acceptance/check_destroy.go @@ -27,7 +27,10 @@ func CheckDestroy(t *testing.T, resource resources.Resource) func(*terraform.Sta } t.Logf("found resource %s in state", resource) ctx := context.Background() - id := decodeSnowflakeId(rs, resource) + id, err := decodeSnowflakeId(rs, resource) + if err != nil { + return err + } if id == nil { return fmt.Errorf("could not get the id of %s", resource) } @@ -45,16 +48,16 @@ func CheckDestroy(t *testing.T, resource resources.Resource) func(*terraform.Sta } } -func decodeSnowflakeId(rs *terraform.ResourceState, resource resources.Resource) sdk.ObjectIdentifier { +func decodeSnowflakeId(rs *terraform.ResourceState, resource resources.Resource) (sdk.ObjectIdentifier, error) { switch resource { case resources.ExternalFunction: - return sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(rs.Primary.ID) + return sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(rs.Primary.ID), nil case resources.Function: - return sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(rs.Primary.ID) + return sdk.ParseSchemaObjectIdentifierWithArguments(rs.Primary.ID) case resources.Procedure: - return sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(rs.Primary.ID) + return sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(rs.Primary.ID), nil default: - return helpers.DecodeSnowflakeID(rs.Primary.ID) + return helpers.DecodeSnowflakeID(rs.Primary.ID), nil } } @@ -213,7 +216,7 @@ var showByIdFunctions = map[resources.Resource]showByIdFunc{ }, } -func runShowById[T any, U sdk.AccountObjectIdentifier | sdk.DatabaseObjectIdentifier | sdk.SchemaObjectIdentifier | sdk.TableColumnIdentifier](ctx context.Context, id sdk.ObjectIdentifier, show func(ctx context.Context, id U) (T, error)) error { +func runShowById[T any, U sdk.AccountObjectIdentifier | sdk.DatabaseObjectIdentifier | sdk.SchemaObjectIdentifier | sdk.TableColumnIdentifier | sdk.SchemaObjectIdentifierWithArguments](ctx context.Context, id sdk.ObjectIdentifier, show func(ctx context.Context, id U) (T, error)) error { idCast, err := asId[U](id) if err != nil { return err @@ -222,7 +225,7 @@ func runShowById[T any, U sdk.AccountObjectIdentifier | sdk.DatabaseObjectIdenti return err } -func asId[T sdk.AccountObjectIdentifier | sdk.DatabaseObjectIdentifier | sdk.SchemaObjectIdentifier | sdk.TableColumnIdentifier](id sdk.ObjectIdentifier) (*T, error) { +func asId[T sdk.AccountObjectIdentifier | sdk.DatabaseObjectIdentifier | sdk.SchemaObjectIdentifier | sdk.TableColumnIdentifier | sdk.SchemaObjectIdentifierWithArguments](id sdk.ObjectIdentifier) (*T, error) { if idCast, ok := id.(T); !ok { return nil, fmt.Errorf("expected %s identifier type, but got: %T", reflect.TypeOf(new(T)).Elem().Name(), id) } else { diff --git a/pkg/acceptance/helpers/ids_generator.go b/pkg/acceptance/helpers/ids_generator.go index ec74e0d8aa..42e247e6d5 100644 --- a/pkg/acceptance/helpers/ids_generator.go +++ b/pkg/acceptance/helpers/ids_generator.go @@ -77,14 +77,26 @@ func (c *IdsGenerator) RandomSchemaObjectIdentifierWithPrefix(prefix string) sdk return sdk.NewSchemaObjectIdentifierInSchema(c.SchemaId(), c.AlphaWithPrefix(prefix)) } -func (c *IdsGenerator) RandomSchemaObjectIdentifierWithArguments(arguments []sdk.DataType) sdk.SchemaObjectIdentifier { - return sdk.NewSchemaObjectIdentifierWithArguments(c.SchemaId().DatabaseName(), c.SchemaId().Name(), c.Alpha(), arguments) -} - func (c *IdsGenerator) RandomSchemaObjectIdentifierInSchema(schemaId sdk.DatabaseObjectIdentifier) sdk.SchemaObjectIdentifier { return sdk.NewSchemaObjectIdentifierInSchema(schemaId, c.Alpha()) } +func (c *IdsGenerator) RandomSchemaObjectIdentifierWithArgumentsOld(arguments ...sdk.DataType) sdk.SchemaObjectIdentifier { + return sdk.NewSchemaObjectIdentifierWithArgumentsOld(c.SchemaId().DatabaseName(), c.SchemaId().Name(), c.Alpha(), arguments) +} + +func (c *IdsGenerator) NewSchemaObjectIdentifierWithArguments(name string, arguments ...sdk.DataType) sdk.SchemaObjectIdentifierWithArguments { + return sdk.NewSchemaObjectIdentifierWithArguments(c.SchemaId().DatabaseName(), c.SchemaId().Name(), name, arguments...) +} + +func (c *IdsGenerator) NewSchemaObjectIdentifierWithArgumentsInSchema(name string, schemaId sdk.DatabaseObjectIdentifier, argumentDataTypes ...sdk.DataType) sdk.SchemaObjectIdentifierWithArguments { + return sdk.NewSchemaObjectIdentifierWithArgumentsInSchema(schemaId, name, argumentDataTypes...) +} + +func (c *IdsGenerator) RandomSchemaObjectIdentifierWithArguments(arguments ...sdk.DataType) sdk.SchemaObjectIdentifierWithArguments { + return sdk.NewSchemaObjectIdentifierWithArguments(c.SchemaId().DatabaseName(), c.SchemaId().Name(), c.Alpha(), arguments...) +} + func (c *IdsGenerator) Alpha() string { return c.AlphaN(6) } diff --git a/pkg/datasources/functions.go b/pkg/datasources/functions.go index 54aa7b2760..b54213442c 100644 --- a/pkg/datasources/functions.go +++ b/pkg/datasources/functions.go @@ -75,7 +75,7 @@ func ReadContextFunctions(ctx context.Context, d *schema.ResourceData, meta inte schemaName := d.Get("schema").(string) request := sdk.NewShowFunctionRequest() - request.WithIn(&sdk.In{Schema: sdk.NewDatabaseObjectIdentifier(databaseName, schemaName)}) + request.WithIn(sdk.In{Schema: sdk.NewDatabaseObjectIdentifier(databaseName, schemaName)}) functions, err := client.Functions.Show(ctx, request) if err != nil { id := d.Id() @@ -92,7 +92,8 @@ func ReadContextFunctions(ctx context.Context, d *schema.ResourceData, meta inte entities := []map[string]interface{}{} for _, item := range functions { - signature, err := parseArguments(item.Arguments) + // TODO(SNOW-1596962): Create argument parsing function that takes argument names into consideration. + signature, err := parseArguments(item.ArgumentsRaw) if err != nil { return diag.FromErr(err) } diff --git a/pkg/resources/external_function.go b/pkg/resources/external_function.go index 5bdff6e3b3..e57e43ce51 100644 --- a/pkg/resources/external_function.go +++ b/pkg/resources/external_function.go @@ -9,7 +9,6 @@ import ( "strings" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/hashicorp/go-cty/cty" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" @@ -318,7 +317,7 @@ func CreateContextExternalFunction(ctx context.Context, d *schema.ResourceData, for _, item := range args { argTypes = append(argTypes, item.ArgDataType) } - sid := sdk.NewSchemaObjectIdentifierWithArguments(database, schemaName, name, argTypes) + sid := sdk.NewSchemaObjectIdentifierWithArgumentsOld(database, schemaName, name, argTypes) d.SetId(sid.FullyQualifiedName()) return ReadContextExternalFunction(ctx, d, meta) } @@ -476,7 +475,7 @@ func UpdateContextExternalFunction(ctx context.Context, d *schema.ResourceData, client := meta.(*provider.Context).Client id := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(d.Id()) - req := sdk.NewAlterFunctionRequest(id.WithoutArguments(), id.Arguments()) + req := sdk.NewAlterFunctionRequest(sdk.NewSchemaObjectIdentifierWithArguments(id.DatabaseName(), id.SchemaName(), id.Name(), id.Arguments()...)) if d.HasChange("comment") { _, new := d.GetChange("comment") if new == "" { @@ -496,7 +495,7 @@ func DeleteContextExternalFunction(ctx context.Context, d *schema.ResourceData, client := meta.(*provider.Context).Client id := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(d.Id()) - req := sdk.NewDropFunctionRequest(id.WithoutArguments(), id.Arguments()) + req := sdk.NewDropFunctionRequest(sdk.NewSchemaObjectIdentifierWithArguments(id.DatabaseName(), id.SchemaName(), id.Name(), id.Arguments()...)) if err := client.Functions.Drop(ctx, req); err != nil { return diag.FromErr(err) } diff --git a/pkg/resources/external_function_acceptance_test.go b/pkg/resources/external_function_acceptance_test.go index 66513f390d..d147071fd2 100644 --- a/pkg/resources/external_function_acceptance_test.go +++ b/pkg/resources/external_function_acceptance_test.go @@ -5,7 +5,6 @@ import ( "testing" acc "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/hashicorp/terraform-plugin-testing/config" @@ -226,7 +225,7 @@ func TestAcc_ExternalFunction_complete(t *testing.T) { } func TestAcc_ExternalFunction_migrateFromVersion085(t *testing.T) { - id := acc.TestClient().Ids.RandomSchemaObjectIdentifierWithArguments([]sdk.DataType{sdk.DataTypeVARCHAR, sdk.DataTypeVARCHAR}) + id := acc.TestClient().Ids.RandomSchemaObjectIdentifierWithArgumentsOld(sdk.DataTypeVARCHAR, sdk.DataTypeVARCHAR) name := id.Name() resourceName := "snowflake_external_function.f" @@ -450,30 +449,30 @@ func externalFunctionConfigWithReturnNullAllowed(database string, schema string, return fmt.Sprintf(` resource "snowflake_api_integration" "test_api_int" { - name = "%[3]s" - api_provider = "aws_api_gateway" - api_aws_role_arn = "arn:aws:iam::000000000001:/role/test" - api_allowed_prefixes = ["https://123456.execute-api.us-west-2.amazonaws.com/prod/"] - enabled = true + name = "%[3]s" + api_provider = "aws_api_gateway" + api_aws_role_arn = "arn:aws:iam::000000000001:/role/test" + api_allowed_prefixes = ["https://123456.execute-api.us-west-2.amazonaws.com/prod/"] + enabled = true } resource "snowflake_external_function" "f" { - name = "%[3]s" - database = "%[1]s" - schema = "%[2]s" - arg { - name = "ARG1" - type = "VARCHAR" - } - arg { - name = "ARG2" - type = "VARCHAR" - } - return_type = "VARIANT" - return_behavior = "IMMUTABLE" - api_integration = snowflake_api_integration.test_api_int.name - url_of_proxy_and_resource = "https://123456.execute-api.us-west-2.amazonaws.com/prod/test_func" - %[4]s + name = "%[3]s" + database = "%[1]s" + schema = "%[2]s" + arg { + name = "ARG1" + type = "VARCHAR" + } + arg { + name = "ARG2" + type = "VARCHAR" + } + return_type = "VARIANT" + return_behavior = "IMMUTABLE" + api_integration = snowflake_api_integration.test_api_int.name + url_of_proxy_and_resource = "https://123456.execute-api.us-west-2.amazonaws.com/prod/test_func" + %[4]s } `, database, schema, name, returnNullAllowedText) @@ -482,46 +481,46 @@ resource "snowflake_external_function" "f" { func externalFunctionConfigIssue2528(database string, schema string, name string, schema2 string) string { return fmt.Sprintf(` resource "snowflake_api_integration" "test_api_int" { - name = "%[3]s" - api_provider = "aws_api_gateway" - api_aws_role_arn = "arn:aws:iam::000000000001:/role/test" - api_allowed_prefixes = ["https://123456.execute-api.us-west-2.amazonaws.com/prod/"] - enabled = true + name = "%[3]s" + api_provider = "aws_api_gateway" + api_aws_role_arn = "arn:aws:iam::000000000001:/role/test" + api_allowed_prefixes = ["https://123456.execute-api.us-west-2.amazonaws.com/prod/"] + enabled = true } resource "snowflake_schema" "s2" { - database = "%[1]s" - name = "%[4]s" + database = "%[1]s" + name = "%[4]s" } resource "snowflake_external_function" "f" { - name = "%[3]s" - database = "%[1]s" - schema = "%[2]s" - arg { - name = "SNS_NOTIF" - type = "OBJECT" - } - return_type = "VARIANT" - return_behavior = "VOLATILE" - api_integration = snowflake_api_integration.test_api_int.name - url_of_proxy_and_resource = "https://123456.execute-api.us-west-2.amazonaws.com/prod/test_func" + name = "%[3]s" + database = "%[1]s" + schema = "%[2]s" + arg { + name = "SNS_NOTIF" + type = "OBJECT" + } + return_type = "VARIANT" + return_behavior = "VOLATILE" + api_integration = snowflake_api_integration.test_api_int.name + url_of_proxy_and_resource = "https://123456.execute-api.us-west-2.amazonaws.com/prod/test_func" } resource "snowflake_external_function" "f2" { - depends_on = [snowflake_schema.s2] - - name = "%[3]s" - database = "%[1]s" - schema = "%[4]s" - arg { - name = "SNS_NOTIF" - type = "OBJECT" - } - return_type = "VARIANT" - return_behavior = "VOLATILE" - api_integration = snowflake_api_integration.test_api_int.name - url_of_proxy_and_resource = "https://123456.execute-api.us-west-2.amazonaws.com/prod/test_func" + depends_on = [snowflake_schema.s2] + + name = "%[3]s" + database = "%[1]s" + schema = "%[4]s" + arg { + name = "SNS_NOTIF" + type = "OBJECT" + } + return_type = "VARIANT" + return_behavior = "VOLATILE" + api_integration = snowflake_api_integration.test_api_int.name + url_of_proxy_and_resource = "https://123456.execute-api.us-west-2.amazonaws.com/prod/test_func" } `, database, schema, name, schema2) } @@ -529,33 +528,33 @@ resource "snowflake_external_function" "f2" { func externalFunctionConfigIssueCurlyHeader(id sdk.SchemaObjectIdentifier) string { return fmt.Sprintf(` resource "snowflake_api_integration" "test_api_int" { - name = "%[3]s" - api_provider = "aws_api_gateway" - api_aws_role_arn = "arn:aws:iam::000000000001:/role/test" - api_allowed_prefixes = ["https://123456.execute-api.us-west-2.amazonaws.com/prod/"] - enabled = true + name = "%[3]s" + api_provider = "aws_api_gateway" + api_aws_role_arn = "arn:aws:iam::000000000001:/role/test" + api_allowed_prefixes = ["https://123456.execute-api.us-west-2.amazonaws.com/prod/"] + enabled = true } resource "snowflake_external_function" "f" { - name = "%[3]s" - database = "%[1]s" - schema = "%[2]s" - arg { - name = "ARG1" - type = "VARCHAR" - } - arg { - name = "ARG2" - type = "VARCHAR" - } - header { + name = "%[3]s" + database = "%[1]s" + schema = "%[2]s" + arg { + name = "ARG1" + type = "VARCHAR" + } + arg { + name = "ARG2" + type = "VARCHAR" + } + header { name = "name" value = "{0}" - } - return_type = "VARIANT" - return_behavior = "IMMUTABLE" - api_integration = snowflake_api_integration.test_api_int.name - url_of_proxy_and_resource = "https://123456.execute-api.us-west-2.amazonaws.com/prod/test_func" + } + return_type = "VARIANT" + return_behavior = "IMMUTABLE" + api_integration = snowflake_api_integration.test_api_int.name + url_of_proxy_and_resource = "https://123456.execute-api.us-west-2.amazonaws.com/prod/test_func" } `, id.DatabaseName(), id.SchemaName(), id.Name()) diff --git a/pkg/resources/external_function_state_upgraders.go b/pkg/resources/external_function_state_upgraders.go index 55d2b95868..aba74585aa 100644 --- a/pkg/resources/external_function_state_upgraders.go +++ b/pkg/resources/external_function_state_upgraders.go @@ -60,7 +60,7 @@ func v085ExternalFunctionStateUpgrader(ctx context.Context, rawState map[string] } } - schemaObjectIdentifierWithArguments := sdk.NewSchemaObjectIdentifierWithArguments(parsedV085ExternalFunctionId.DatabaseName, parsedV085ExternalFunctionId.SchemaName, parsedV085ExternalFunctionId.ExternalFunctionName, argDataTypes) + schemaObjectIdentifierWithArguments := sdk.NewSchemaObjectIdentifierWithArgumentsOld(parsedV085ExternalFunctionId.DatabaseName, parsedV085ExternalFunctionId.SchemaName, parsedV085ExternalFunctionId.ExternalFunctionName, argDataTypes) rawState["id"] = schemaObjectIdentifierWithArguments.FullyQualifiedName() oldDatabase := rawState["database"].(string) diff --git a/pkg/resources/function.go b/pkg/resources/function.go index ea6da70f53..f222c9b0a7 100644 --- a/pkg/resources/function.go +++ b/pkg/resources/function.go @@ -8,7 +8,6 @@ import ( "strings" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" "github.com/hashicorp/go-cty/cty" @@ -50,6 +49,7 @@ var functionSchema = map[string]*schema.Schema{ }, Description: "The argument name", }, + // TODO(SNOW-1596962): Fully support VECTOR data type sdk.ParseFunctionArgumentsFromString could be a base for another function that takes argument names into consideration. "type": { Type: schema.TypeString, Required: true, @@ -159,7 +159,6 @@ var functionSchema = map[string]*schema.Schema{ }, } -// Function returns a pointer to the resource representing a stored function. func Function() *schema.Resource { return &schema.Resource{ SchemaVersion: 1, @@ -225,11 +224,11 @@ func createJavaFunction(ctx context.Context, d *schema.ResourceData, meta interf // create request with required request := sdk.NewCreateForJavaFunctionRequest(id, *returns, handler) functionDefinition := d.Get("statement").(string) - request.WithFunctionDefinition(sdk.String(functionDefinition)) + request.WithFunctionDefinition(functionDefinition) // Set optionals if v, ok := d.GetOk("is_secure"); ok { - request.WithSecure(sdk.Bool(v.(bool))) + request.WithSecure(v.(bool)) } arguments, diags := parseFunctionArguments(d) if diags != nil { @@ -239,16 +238,16 @@ func createJavaFunction(ctx context.Context, d *schema.ResourceData, meta interf request.WithArguments(arguments) } if v, ok := d.GetOk("null_input_behavior"); ok { - request.WithNullInputBehavior(sdk.Pointer(sdk.NullInputBehavior(v.(string)))) + request.WithNullInputBehavior(sdk.NullInputBehavior(v.(string))) } if v, ok := d.GetOk("return_behavior"); ok { - request.WithReturnResultsBehavior(sdk.Pointer(sdk.ReturnResultsBehavior(v.(string)))) + request.WithReturnResultsBehavior(sdk.ReturnResultsBehavior(v.(string))) } if v, ok := d.GetOk("runtime_version"); ok { - request.WithRuntimeVersion(sdk.String(v.(string))) + request.WithRuntimeVersion(v.(string)) } if v, ok := d.GetOk("comment"); ok { - request.WithComment(sdk.String(v.(string))) + request.WithComment(v.(string)) } if _, ok := d.GetOk("imports"); ok { imports := []sdk.FunctionImportRequest{} @@ -265,7 +264,7 @@ func createJavaFunction(ctx context.Context, d *schema.ResourceData, meta interf request.WithPackages(packages) } if v, ok := d.GetOk("target_path"); ok { - request.WithTargetPath(sdk.String(v.(string))) + request.WithTargetPath(v.(string)) } if err := client.Functions.CreateForJava(ctx, request); err != nil { @@ -275,7 +274,7 @@ func createJavaFunction(ctx context.Context, d *schema.ResourceData, meta interf for _, item := range arguments { argumentTypes = append(argumentTypes, item.ArgDataType) } - nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes) + nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes...) d.SetId(nid.FullyQualifiedName()) return ReadContextFunction(ctx, d, meta) } @@ -297,11 +296,11 @@ func createScalaFunction(ctx context.Context, d *schema.ResourceData, meta inter handler := d.Get("handler").(string) // create request with required request := sdk.NewCreateForScalaFunctionRequest(id, returnDataType, handler) - request.WithFunctionDefinition(sdk.String(functionDefinition)) + request.WithFunctionDefinition(functionDefinition) // Set optionals if v, ok := d.GetOk("is_secure"); ok { - request.WithSecure(sdk.Bool(v.(bool))) + request.WithSecure(v.(bool)) } arguments, diags := parseFunctionArguments(d) if diags != nil { @@ -311,16 +310,16 @@ func createScalaFunction(ctx context.Context, d *schema.ResourceData, meta inter request.WithArguments(arguments) } if v, ok := d.GetOk("null_input_behavior"); ok { - request.WithNullInputBehavior(sdk.Pointer(sdk.NullInputBehavior(v.(string)))) + request.WithNullInputBehavior(sdk.NullInputBehavior(v.(string))) } if v, ok := d.GetOk("return_behavior"); ok { - request.WithReturnResultsBehavior(sdk.Pointer(sdk.ReturnResultsBehavior(v.(string)))) + request.WithReturnResultsBehavior(sdk.ReturnResultsBehavior(v.(string))) } if v, ok := d.GetOk("runtime_version"); ok { - request.WithRuntimeVersion(sdk.String(v.(string))) + request.WithRuntimeVersion(v.(string)) } if v, ok := d.GetOk("comment"); ok { - request.WithComment(sdk.String(v.(string))) + request.WithComment(v.(string)) } if _, ok := d.GetOk("imports"); ok { imports := []sdk.FunctionImportRequest{} @@ -337,7 +336,7 @@ func createScalaFunction(ctx context.Context, d *schema.ResourceData, meta inter request.WithPackages(packages) } if v, ok := d.GetOk("target_path"); ok { - request.WithTargetPath(sdk.String(v.(string))) + request.WithTargetPath(v.(string)) } if err := client.Functions.CreateForScala(ctx, request); err != nil { @@ -347,7 +346,7 @@ func createScalaFunction(ctx context.Context, d *schema.ResourceData, meta inter for _, item := range arguments { argumentTypes = append(argumentTypes, item.ArgDataType) } - nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes) + nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes...) d.SetId(nid.FullyQualifiedName()) return ReadContextFunction(ctx, d, meta) } @@ -370,7 +369,7 @@ func createSQLFunction(ctx context.Context, d *schema.ResourceData, meta interfa // Set optionals if v, ok := d.GetOk("is_secure"); ok { - request.WithSecure(sdk.Bool(v.(bool))) + request.WithSecure(v.(bool)) } arguments, diags := parseFunctionArguments(d) if diags != nil { @@ -380,10 +379,10 @@ func createSQLFunction(ctx context.Context, d *schema.ResourceData, meta interfa request.WithArguments(arguments) } if v, ok := d.GetOk("return_behavior"); ok { - request.WithReturnResultsBehavior(sdk.Pointer(sdk.ReturnResultsBehavior(v.(string)))) + request.WithReturnResultsBehavior(sdk.ReturnResultsBehavior(v.(string))) } if v, ok := d.GetOk("comment"); ok { - request.WithComment(sdk.String(v.(string))) + request.WithComment(v.(string)) } if err := client.Functions.CreateForSQL(ctx, request); err != nil { @@ -393,7 +392,7 @@ func createSQLFunction(ctx context.Context, d *schema.ResourceData, meta interfa for _, item := range arguments { argumentTypes = append(argumentTypes, item.ArgDataType) } - nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes) + nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes...) d.SetId(nid.FullyQualifiedName()) return ReadContextFunction(ctx, d, meta) } @@ -415,11 +414,11 @@ func createPythonFunction(ctx context.Context, d *schema.ResourceData, meta inte handler := d.Get("handler").(string) // create request with required request := sdk.NewCreateForPythonFunctionRequest(id, *returns, version, handler) - request.WithFunctionDefinition(sdk.String(functionDefinition)) + request.WithFunctionDefinition(functionDefinition) // Set optionals if v, ok := d.GetOk("is_secure"); ok { - request.WithSecure(sdk.Bool(v.(bool))) + request.WithSecure(v.(bool)) } arguments, diags := parseFunctionArguments(d) if diags != nil { @@ -429,14 +428,14 @@ func createPythonFunction(ctx context.Context, d *schema.ResourceData, meta inte request.WithArguments(arguments) } if v, ok := d.GetOk("null_input_behavior"); ok { - request.WithNullInputBehavior(sdk.Pointer(sdk.NullInputBehavior(v.(string)))) + request.WithNullInputBehavior(sdk.NullInputBehavior(v.(string))) } if v, ok := d.GetOk("return_behavior"); ok { - request.WithReturnResultsBehavior(sdk.Pointer(sdk.ReturnResultsBehavior(v.(string)))) + request.WithReturnResultsBehavior(sdk.ReturnResultsBehavior(v.(string))) } if v, ok := d.GetOk("comment"); ok { - request.WithComment(sdk.String(v.(string))) + request.WithComment(v.(string)) } if _, ok := d.GetOk("imports"); ok { imports := []sdk.FunctionImportRequest{} @@ -460,7 +459,7 @@ func createPythonFunction(ctx context.Context, d *schema.ResourceData, meta inte for _, item := range arguments { argumentTypes = append(argumentTypes, item.ArgDataType) } - nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes) + nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes...) d.SetId(nid.FullyQualifiedName()) return ReadContextFunction(ctx, d, meta) } @@ -483,7 +482,7 @@ func createJavascriptFunction(ctx context.Context, d *schema.ResourceData, meta // Set optionals if v, ok := d.GetOk("is_secure"); ok { - request.WithSecure(sdk.Bool(v.(bool))) + request.WithSecure(v.(bool)) } arguments, diags := parseFunctionArguments(d) if diags != nil { @@ -493,13 +492,13 @@ func createJavascriptFunction(ctx context.Context, d *schema.ResourceData, meta request.WithArguments(arguments) } if v, ok := d.GetOk("null_input_behavior"); ok { - request.WithNullInputBehavior(sdk.Pointer(sdk.NullInputBehavior(v.(string)))) + request.WithNullInputBehavior(sdk.NullInputBehavior(v.(string))) } if v, ok := d.GetOk("return_behavior"); ok { - request.WithReturnResultsBehavior(sdk.Pointer(sdk.ReturnResultsBehavior(v.(string)))) + request.WithReturnResultsBehavior(sdk.ReturnResultsBehavior(v.(string))) } if v, ok := d.GetOk("comment"); ok { - request.WithComment(sdk.String(v.(string))) + request.WithComment(v.(string)) } if err := client.Functions.CreateForJavascript(ctx, request); err != nil { @@ -509,7 +508,7 @@ func createJavascriptFunction(ctx context.Context, d *schema.ResourceData, meta for _, item := range arguments { argumentTypes = append(argumentTypes, item.ArgDataType) } - nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes) + nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes...) d.SetId(nid.FullyQualifiedName()) return ReadContextFunction(ctx, d, meta) } @@ -518,7 +517,10 @@ func ReadContextFunction(ctx context.Context, d *schema.ResourceData, meta inter diags := diag.Diagnostics{} client := meta.(*provider.Context).Client - id := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(d.Id()) + id, err := sdk.ParseSchemaObjectIdentifierWithArguments(d.Id()) + if err != nil { + return diag.FromErr(err) + } if err := d.Set("name", id.Name()); err != nil { return diag.FromErr(err) } @@ -534,7 +536,7 @@ func ReadContextFunction(ctx context.Context, d *schema.ResourceData, meta inter for i, arg := range arguments { argumentTypes[i] = arg.(map[string]interface{})["type"].(string) } - functionDetails, err := client.Functions.Describe(ctx, sdk.NewDescribeFunctionRequest(id.WithoutArguments(), id.Arguments())) + functionDetails, err := client.Functions.Describe(ctx, id) if err != nil { // if function is not found then mark resource to be removed from state file during apply or refresh d.SetId("") @@ -630,37 +632,34 @@ func ReadContextFunction(ctx context.Context, d *schema.ResourceData, meta inter } } - // Show functions to set is_secure and comment - request := sdk.NewShowFunctionRequest().WithIn(&sdk.In{Schema: sdk.NewDatabaseObjectIdentifier(id.DatabaseName(), id.SchemaName())}).WithLike(&sdk.Like{Pattern: sdk.String(id.Name())}) - functions, err := client.Functions.Show(ctx, request) + function, err := client.Functions.ShowByID(ctx, id) if err != nil { return diag.FromErr(err) } - for _, function := range functions { - signature := strings.Split(function.Arguments, " RETURN ")[0] - signature = strings.ReplaceAll(signature, " ", "") - id.FullyQualifiedName() - if signature == id.ArgumentsSignature() { - if err := d.Set("is_secure", function.IsSecure); err != nil { - return diag.FromErr(err) - } - if err := d.Set("comment", function.Description); err != nil { - return diag.FromErr(err) - } - } + + if err := d.Set("is_secure", function.IsSecure); err != nil { + return diag.FromErr(err) } + + if err := d.Set("comment", function.Description); err != nil { + return diag.FromErr(err) + } + return diags } func UpdateContextFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client - id := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(d.Id()) + id, err := sdk.ParseSchemaObjectIdentifierWithArguments(d.Id()) + if err != nil { + return diag.FromErr(err) + } if d.HasChange("name") { name := d.Get("name").(string) - newId := sdk.NewSchemaObjectIdentifierWithArguments(id.DatabaseName(), id.SchemaName(), name, id.Arguments()) + newId := sdk.NewSchemaObjectIdentifierWithArguments(id.DatabaseName(), id.SchemaName(), name, id.ArgumentDataTypes()...) - if err := client.Functions.Alter(ctx, sdk.NewAlterFunctionRequest(id.WithoutArguments(), id.Arguments()).WithRenameTo(sdk.Pointer(newId.WithoutArguments()))); err != nil { + if err := client.Functions.Alter(ctx, sdk.NewAlterFunctionRequest(id).WithRenameTo(newId.SchemaObjectId())); err != nil { return diag.FromErr(err) } @@ -671,11 +670,11 @@ func UpdateContextFunction(ctx context.Context, d *schema.ResourceData, meta int if d.HasChange("is_secure") { secure := d.Get("is_secure") if secure.(bool) { - if err := client.Functions.Alter(ctx, sdk.NewAlterFunctionRequest(id.WithoutArguments(), id.Arguments()).WithSetSecure(sdk.Bool(true))); err != nil { + if err := client.Functions.Alter(ctx, sdk.NewAlterFunctionRequest(id).WithSetSecure(true)); err != nil { return diag.FromErr(err) } } else { - if err := client.Functions.Alter(ctx, sdk.NewAlterFunctionRequest(id.WithoutArguments(), id.Arguments()).WithUnsetSecure(sdk.Bool(true))); err != nil { + if err := client.Functions.Alter(ctx, sdk.NewAlterFunctionRequest(id).WithUnsetSecure(true)); err != nil { return diag.FromErr(err) } } @@ -684,11 +683,11 @@ func UpdateContextFunction(ctx context.Context, d *schema.ResourceData, meta int if d.HasChange("comment") { comment := d.Get("comment") if comment != "" { - if err := client.Functions.Alter(ctx, sdk.NewAlterFunctionRequest(id.WithoutArguments(), id.Arguments()).WithSetComment(sdk.String(comment.(string)))); err != nil { + if err := client.Functions.Alter(ctx, sdk.NewAlterFunctionRequest(id).WithSetComment(comment.(string))); err != nil { return diag.FromErr(err) } } else { - if err := client.Functions.Alter(ctx, sdk.NewAlterFunctionRequest(id.WithoutArguments(), id.Arguments()).WithUnsetComment(sdk.Bool(true))); err != nil { + if err := client.Functions.Alter(ctx, sdk.NewAlterFunctionRequest(id).WithUnsetComment(true)); err != nil { return diag.FromErr(err) } } @@ -700,8 +699,11 @@ func UpdateContextFunction(ctx context.Context, d *schema.ResourceData, meta int func DeleteContextFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client - id := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(d.Id()) - if err := client.Functions.Drop(ctx, sdk.NewDropFunctionRequest(id.WithoutArguments(), id.Arguments())); err != nil { + id, err := sdk.ParseSchemaObjectIdentifierWithArguments(d.Id()) + if err != nil { + return diag.FromErr(err) + } + if err := client.Functions.Drop(ctx, sdk.NewDropFunctionRequest(id)); err != nil { return diag.FromErr(err) } d.SetId("") @@ -762,13 +764,13 @@ func parseFunctionReturnsRequest(s string) (*sdk.FunctionReturnsRequest, diag.Di for _, item := range columns { cr = append(cr, *sdk.NewFunctionColumnRequest(item.ColumnName, item.ColumnDataType)) } - returns.WithTable(sdk.NewFunctionReturnsTableRequest().WithColumns(cr)) + returns.WithTable(*sdk.NewFunctionReturnsTableRequest().WithColumns(cr)) } else { returnDataType, diags := convertFunctionDataType(s) if diags != nil { return nil, diags } - returns.WithResultDataType(sdk.NewFunctionReturnsResultDataTypeRequest(returnDataType)) + returns.WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(returnDataType)) } return returns, nil } diff --git a/pkg/resources/function_acceptance_test.go b/pkg/resources/function_acceptance_test.go index 224d5985a7..b4e99be301 100644 --- a/pkg/resources/function_acceptance_test.go +++ b/pkg/resources/function_acceptance_test.go @@ -185,7 +185,7 @@ func TestAcc_Function_complex(t *testing.T) { // proves issue https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2490 func TestAcc_Function_migrateFromVersion085(t *testing.T) { - id := acc.TestClient().Ids.RandomSchemaObjectIdentifierWithArguments([]sdk.DataType{sdk.DataTypeVARCHAR}) + id := acc.TestClient().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeVARCHAR) name := id.Name() comment := random.Comment() resourceName := "snowflake_function.f" @@ -218,8 +218,13 @@ func TestAcc_Function_migrateFromVersion085(t *testing.T) { ), }, { - ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, - Config: functionConfig(acc.TestDatabaseName, acc.TestSchemaName, name, comment), + ExternalProviders: map[string]resource.ExternalProvider{ + "snowflake": { + VersionConstraint: "=0.94.1", + Source: "Snowflake-Labs/snowflake", + }, + }, + Config: functionConfig(acc.TestDatabaseName, acc.TestSchemaName, name, comment), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr(resourceName, "id", id.FullyQualifiedName()), resource.TestCheckResourceAttr(resourceName, "name", name), @@ -231,6 +236,40 @@ func TestAcc_Function_migrateFromVersion085(t *testing.T) { }) } +func TestAcc_Function_EnsureSmoothResourceIdMigrationToV0950(t *testing.T) { + name := acc.TestClient().Ids.RandomAccountObjectIdentifier().Name() + resourceName := "snowflake_function.f" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: acc.CheckDestroy(t, resources.Function), + Steps: []resource.TestStep{ + { + ExternalProviders: map[string]resource.ExternalProvider{ + "snowflake": { + VersionConstraint: "=0.94.1", + Source: "Snowflake-Labs/snowflake", + }, + }, + Config: functionConfigWithMoreArguments(acc.TestDatabaseName, acc.TestSchemaName, name), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf(`"%s"."%s"."%s"(VARCHAR, FLOAT, NUMBER)`, acc.TestDatabaseName, acc.TestSchemaName, name)), + ), + }, + { + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + Config: functionConfigWithMoreArguments(acc.TestDatabaseName, acc.TestSchemaName, name), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf(`"%s"."%s"."%s"(VARCHAR, FLOAT, NUMBER)`, acc.TestDatabaseName, acc.TestSchemaName, name)), + ), + }, + }, + }) +} + func TestAcc_Function_Rename(t *testing.T) { name := acc.TestClient().Ids.Alpha() newName := acc.TestClient().Ids.Alpha() @@ -272,6 +311,32 @@ func TestAcc_Function_Rename(t *testing.T) { }) } +func functionConfigWithMoreArguments(database string, schema string, name string) string { + return fmt.Sprintf(` +resource "snowflake_function" "f" { + database = "%[1]s" + schema = "%[2]s" + name = "%[3]s" + return_type = "VARCHAR" + return_behavior = "IMMUTABLE" + statement = "SELECT A" + + arguments { + name = "A" + type = "VARCHAR" + } + arguments { + name = "B" + type = "FLOAT" + } + arguments { + name = "C" + type = "NUMBER" + } +} +`, database, schema, name) +} + func functionConfig(database string, schema string, name string, comment string) string { return fmt.Sprintf(` resource "snowflake_function" "f" { diff --git a/pkg/resources/function_state_upgraders.go b/pkg/resources/function_state_upgraders.go index 43f699e5d5..501e44f1dc 100644 --- a/pkg/resources/function_state_upgraders.go +++ b/pkg/resources/function_state_upgraders.go @@ -55,7 +55,7 @@ func v085FunctionIdStateUpgrader(ctx context.Context, rawState map[string]interf argDataTypes[i] = argDataType } - schemaObjectIdentifierWithArguments := sdk.NewSchemaObjectIdentifierWithArguments(parsedV085FunctionId.DatabaseName, parsedV085FunctionId.SchemaName, parsedV085FunctionId.FunctionName, argDataTypes) + schemaObjectIdentifierWithArguments := sdk.NewSchemaObjectIdentifierWithArgumentsOld(parsedV085FunctionId.DatabaseName, parsedV085FunctionId.SchemaName, parsedV085FunctionId.FunctionName, argDataTypes) rawState["id"] = schemaObjectIdentifierWithArguments.FullyQualifiedName() return rawState, nil diff --git a/pkg/resources/procedure.go b/pkg/resources/procedure.go index 3f52075ee0..118bcde253 100644 --- a/pkg/resources/procedure.go +++ b/pkg/resources/procedure.go @@ -9,7 +9,6 @@ import ( "strings" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/hashicorp/go-cty/cty" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" @@ -281,7 +280,7 @@ func createJavaProcedure(ctx context.Context, d *schema.ResourceData, meta inter for _, item := range args { argTypes = append(argTypes, item.ArgDataType) } - sid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argTypes) + sid := sdk.NewSchemaObjectIdentifierWithArgumentsOld(database, schema, name, argTypes) d.SetId(sid.FullyQualifiedName()) return ReadContextProcedure(ctx, d, meta) } @@ -333,7 +332,7 @@ func createJavaScriptProcedure(ctx context.Context, d *schema.ResourceData, meta for _, item := range args { argTypes = append(argTypes, item.ArgDataType) } - sid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argTypes) + sid := sdk.NewSchemaObjectIdentifierWithArgumentsOld(database, schema, name, argTypes) d.SetId(sid.FullyQualifiedName()) return ReadContextProcedure(ctx, d, meta) } @@ -395,7 +394,7 @@ func createScalaProcedure(ctx context.Context, d *schema.ResourceData, meta inte for _, item := range args { argTypes = append(argTypes, item.ArgDataType) } - sid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argTypes) + sid := sdk.NewSchemaObjectIdentifierWithArgumentsOld(database, schema, name, argTypes) d.SetId(sid.FullyQualifiedName()) return ReadContextProcedure(ctx, d, meta) } @@ -446,7 +445,7 @@ func createSQLProcedure(ctx context.Context, d *schema.ResourceData, meta interf for _, item := range args { argTypes = append(argTypes, item.ArgDataType) } - sid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argTypes) + sid := sdk.NewSchemaObjectIdentifierWithArgumentsOld(database, schema, name, argTypes) d.SetId(sid.FullyQualifiedName()) return ReadContextProcedure(ctx, d, meta) } @@ -516,7 +515,7 @@ func createPythonProcedure(ctx context.Context, d *schema.ResourceData, meta int for _, item := range args { argTypes = append(argTypes, item.ArgDataType) } - sid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argTypes) + sid := sdk.NewSchemaObjectIdentifierWithArgumentsOld(database, schema, name, argTypes) d.SetId(sid.FullyQualifiedName()) return ReadContextProcedure(ctx, d, meta) } @@ -657,7 +656,7 @@ func UpdateContextProcedure(ctx context.Context, d *schema.ResourceData, meta in id := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(d.Id()) if d.HasChange("name") { - newId := sdk.NewSchemaObjectIdentifierWithArguments(id.DatabaseName(), id.SchemaName(), d.Get("name").(string), id.Arguments()) + newId := sdk.NewSchemaObjectIdentifierWithArgumentsOld(id.DatabaseName(), id.SchemaName(), d.Get("name").(string), id.Arguments()) err := client.Procedures.Alter(ctx, sdk.NewAlterProcedureRequest(id.WithoutArguments(), id.Arguments()).WithRenameTo(sdk.Pointer(newId.WithoutArguments()))) if err != nil { diff --git a/pkg/resources/procedure_acceptance_test.go b/pkg/resources/procedure_acceptance_test.go index 4eeecabfd8..29e89243c1 100644 --- a/pkg/resources/procedure_acceptance_test.go +++ b/pkg/resources/procedure_acceptance_test.go @@ -6,7 +6,6 @@ import ( "testing" acc "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/hashicorp/terraform-plugin-testing/config" @@ -245,20 +244,20 @@ func TestAcc_Procedure_migrateFromVersion085(t *testing.T) { func procedureConfig(database string, schema string, name string) string { return fmt.Sprintf(` resource "snowflake_procedure" "p" { - database = "%[1]s" - schema = "%[2]s" - name = "%[3]s" - language = "JAVASCRIPT" - return_type = "VARCHAR" - statement = <, n) where right now can be either INT or FLOAT and n is the number of elements in the VECTOR. +// Snowflake returns vectors with their exact type and this function supports it. +func ParseFunctionArgumentsFromString(arguments string) ([]DataType, error) { + dataTypes := make([]DataType, 0) + + if len(arguments) > 0 && arguments[0] == '(' && arguments[len(arguments)-1] == ')' { + arguments = arguments[1 : len(arguments)-1] + } + stringBuffer := bytes.NewBufferString(arguments) + + for stringBuffer.Len() > 0 { + // We use another buffer to peek into next data type (needed for vector parsing) + peekDataType, _ := bytes.NewBufferString(stringBuffer.String()).ReadString(',') + peekDataType = strings.TrimSpace(peekDataType) + + switch { + // For now, only vectors need special parsing behavior + case strings.HasPrefix(peekDataType, "VECTOR"): + vectorDataType, err := stringBuffer.ReadString(')') + if err != nil { + return nil, fmt.Errorf("failed to parse vector type, couldn't find the closing bracket, err = %w", err) + } + + vectorDataType = strings.TrimSpace(vectorDataType) + vectorTypeBuffer := bytes.NewBufferString(vectorDataType) + if _, err := vectorTypeBuffer.ReadString('('); err != nil { + return nil, fmt.Errorf("failed to parse vector type, couldn't find the opening bracket, err = %w", err) + } + + vectorInnerType, err := vectorTypeBuffer.ReadString(',') + if err != nil { + return nil, fmt.Errorf("failed to parse vector inner type: %w", err) + } + + vectorInnerType = vectorInnerType[:len(vectorInnerType)-1] + if !slices.Contains(allowedVectorInnerTypes, DataType(vectorInnerType)) { + return nil, fmt.Errorf("invalid vector inner type: %s, allowed vector types are: %v", vectorInnerType, allowedVectorInnerTypes) + } + + vectorSize, err := vectorTypeBuffer.ReadString(')') + if err != nil { + return nil, fmt.Errorf("failed to parse vector size: %w", err) + } + + vectorSize = strings.TrimSpace(vectorSize[:len(vectorSize)-1]) + _, err = strconv.ParseInt(vectorSize, 0, 8) + if err != nil { + return nil, fmt.Errorf("invalid vector size: %s (not a number): %w", vectorSize, err) + } + + if stringBuffer.Len() > 0 { + commaByte, err := stringBuffer.ReadByte() + if commaByte != ',' { + return nil, fmt.Errorf("expected a comma delimited string but found %s", string(commaByte)) + } + if err != nil { + return nil, err + } + } + dataTypes = append(dataTypes, DataType(vectorDataType)) + default: + dataType, err := stringBuffer.ReadString(',') + if err == nil { + dataType = dataType[:len(dataType)-1] + } + dataType = strings.TrimSpace(dataType) + dataTypes = append(dataTypes, DataType(dataType)) + } + } + + return dataTypes, nil +} diff --git a/pkg/sdk/identifier_parsers_test.go b/pkg/sdk/identifier_parsers_test.go index de86bbc9dc..078cbbc30d 100644 --- a/pkg/sdk/identifier_parsers_test.go +++ b/pkg/sdk/identifier_parsers_test.go @@ -80,6 +80,27 @@ func Test_ParseIdentifierString(t *testing.T) { require.ErrorContains(t, err, `unable to parse identifier: "ab""c".def, currently identifiers containing double quotes are not supported in the provider`) }) + t.Run("returns error when identifier contains opening parenthesis", func(t *testing.T) { + input := `"ab(c".def` + _, err := ParseIdentifierString(input) + + require.ErrorContains(t, err, `unable to parse identifier: "ab(c".def, currently identifiers containing opening and closing parentheses '()' are not supported in the provider`) + }) + + t.Run("returns error when identifier contains closing parenthesis", func(t *testing.T) { + input := `"ab)c".def` + _, err := ParseIdentifierString(input) + + require.ErrorContains(t, err, `unable to parse identifier: "ab)c".def, currently identifiers containing opening and closing parentheses '()' are not supported in the provider`) + }) + + t.Run("returns error when identifier contains opening and closing parentheses", func(t *testing.T) { + input := `"ab()c".def` + _, err := ParseIdentifierString(input) + + require.ErrorContains(t, err, `unable to parse identifier: "ab()c".def, currently identifiers containing opening and closing parentheses '()' are not supported in the provider`) + }) + t.Run("returns parts correctly with dots inside", func(t *testing.T) { input := `"ab.c".def` expected := []string{`ab.c`, "def"} @@ -250,3 +271,101 @@ func Test_ParseObjectIdentifierString(t *testing.T) { }) } } + +func Test_ParseFunctionArgumentsFromString(t *testing.T) { + testCases := []struct { + Arguments string + Expected []DataType + Error string + }{ + {Arguments: `()`, Expected: []DataType{}}, + {Arguments: `(FLOAT, NUMBER, TIME)`, Expected: []DataType{DataTypeFloat, DataTypeNumber, DataTypeTime}}, + {Arguments: `FLOAT, NUMBER, TIME`, Expected: []DataType{DataTypeFloat, DataTypeNumber, DataTypeTime}}, + {Arguments: `(FLOAT, NUMBER, VECTOR(FLOAT, 20))`, Expected: []DataType{DataTypeFloat, DataTypeNumber, DataType("VECTOR(FLOAT, 20)")}}, + {Arguments: `FLOAT, NUMBER, VECTOR(FLOAT, 20)`, Expected: []DataType{DataTypeFloat, DataTypeNumber, DataType("VECTOR(FLOAT, 20)")}}, + {Arguments: `(VECTOR(FLOAT, 10), NUMBER, VECTOR(FLOAT, 20))`, Expected: []DataType{DataType("VECTOR(FLOAT, 10)"), DataTypeNumber, DataType("VECTOR(FLOAT, 20)")}}, + {Arguments: `VECTOR(FLOAT, 10)| NUMBER, VECTOR(FLOAT, 20)`, Error: "expected a comma delimited string but found |"}, + {Arguments: `FLOAT, NUMBER, VECTORFLOAT, 20)`, Error: `failed to parse vector type, couldn't find the opening bracket, err = EOF`}, + {Arguments: `FLOAT, NUMBER, VECTORFLOAT, 20), VECTOR(INT, 10)`, Error: `failed to parse vector type, couldn't find the opening bracket, err = EOF`}, + {Arguments: `FLOAT, NUMBER, VECTOR(FLOAT, 20`, Error: `failed to parse vector type, couldn't find the closing bracket, err = EOF`}, + {Arguments: `FLOAT, NUMBER, VECTOR(FLOAT, 20, VECTOR(INT, 10)`, Error: `invalid vector size: 20, VECTOR(INT, 10 (not a number): strconv.ParseInt: parsing "20, VECTOR(INT, 10": invalid syntax`}, + {Arguments: `(FLOAT, VARCHAR(200), TIME)`, Expected: []DataType{DataTypeFloat, DataType("VARCHAR(200)"), DataTypeTime}}, + {Arguments: `(FLOAT, VARCHAR(200))`, Expected: []DataType{DataTypeFloat, DataType("VARCHAR(200)")}}, + {Arguments: `(VARCHAR(200), FLOAT)`, Expected: []DataType{DataType("VARCHAR(200)"), DataTypeFloat}}, + {Arguments: `(FLOAT, NUMBER, VECTOR(VARCHAR, 20))`, Error: `invalid vector inner type: VARCHAR, allowed vector types are`}, + {Arguments: `(FLOAT, NUMBER, VECTOR(INT, INT))`, Error: `invalid vector size: INT (not a number): strconv.ParseInt: parsing "INT": invalid syntax`}, + {Arguments: `FLOAT, NUMBER, VECTOR(20, FLOAT)`, Error: `invalid vector inner type: 20, allowed vector types are`}, + // As the function is only used for identifiers with arguments the following cases are not supported (because they represent concrete types which are not used as part of the identifiers). + {Arguments: `(FLOAT, NUMBER(10, 2), TIME)`, Expected: []DataType{DataTypeFloat, DataType("NUMBER(10"), DataType("2)"), DataTypeTime}}, + {Arguments: `(FLOAT, NUMBER(10, 2))`, Expected: []DataType{DataTypeFloat, DataType("NUMBER(10"), DataType("2)")}}, + {Arguments: `(NUMBER(10, 2), FLOAT)`, Expected: []DataType{DataType("NUMBER(10"), DataType("2)"), DataTypeFloat}}, + } + + for _, testCase := range testCases { + t.Run(fmt.Sprintf("parsing function arguments %s", testCase.Arguments), func(t *testing.T) { + dataTypes, err := ParseFunctionArgumentsFromString(testCase.Arguments) + if testCase.Error != "" { + assert.ErrorContains(t, err, testCase.Error) + } else { + assert.NoError(t, err) + assert.Equal(t, testCase.Expected, dataTypes) + } + }) + } +} + +func TestNewSchemaObjectIdentifierWithArgumentsFromFullyQualifiedName(t *testing.T) { + testCases := []struct { + Input SchemaObjectIdentifierWithArguments + Error string + }{ + {Input: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`, DataTypeFloat, DataTypeNumber, DataTypeTimestampTZ)}, + {Input: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`, DataTypeFloat, "VECTOR(INT, 20)")}, + {Input: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`, "VECTOR(INT, 20)", DataTypeFloat)}, + {Input: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`, DataTypeFloat, "VECTOR(INT, 20)", "VECTOR(INT, 10)")}, + {Input: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`, DataTypeTime, "VECTOR(INT, 20)", "VECTOR(FLOAT, 10)", DataTypeFloat)}, + // TODO(SNOW-1571674): Won't work, because of the assumption that identifiers are not containing '(' and ')' parentheses (unfortunately, we're not able to produce meaningful errors for those cases) + {Input: NewSchemaObjectIdentifierWithArguments(`ab()c`, `def()`, `()ghi`, DataTypeTime, "VECTOR(INT, 20)", "VECTOR(FLOAT, 10)", DataTypeFloat), Error: `unable to read identifier: "ab`}, + {Input: NewSchemaObjectIdentifierWithArguments(`ab(,)c`, `,def()`, `()ghi,`, DataTypeTime, "VECTOR(INT, 20)", "VECTOR(FLOAT, 10)", DataTypeFloat), Error: `unable to read identifier: "ab`}, + {Input: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`)}, + } + + for _, testCase := range testCases { + t.Run(fmt.Sprintf("processing %s", testCase.Input.FullyQualifiedName()), func(t *testing.T) { + id, err := ParseSchemaObjectIdentifierWithArguments(testCase.Input.FullyQualifiedName()) + + if testCase.Error != "" { + assert.ErrorContains(t, err, testCase.Error) + } else { + assert.NoError(t, err) + assert.Equal(t, testCase.Input.FullyQualifiedName(), id.FullyQualifiedName()) + } + }) + } +} + +func TestNewSchemaObjectIdentifierWithArgumentsFromFullyQualifiedName_WithRawInput(t *testing.T) { + testCases := []struct { + RawInput string + ExpectedIdentifierStructure SchemaObjectIdentifierWithArguments + Error string + }{ + {RawInput: `abc.def.ghi()`, ExpectedIdentifierStructure: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`)}, + {RawInput: `abc.def.ghi(FLOAT, VECTOR(INT, 20))`, ExpectedIdentifierStructure: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`, DataTypeFloat, "VECTOR(INT, 20)")}, + // TODO(SNOW-1571674): Won't work, because of the assumption that identifiers are not containing '(' and ')' parentheses (unfortunately, we're not able to produce meaningful errors for those cases) + {RawInput: `abc."(ef".ghi(FLOAT, VECTOR(INT, 20))`, Error: `unable to read identifier: abc."`}, + } + + for _, testCase := range testCases { + t.Run(fmt.Sprintf("processing %s", testCase.ExpectedIdentifierStructure.FullyQualifiedName()), func(t *testing.T) { + id, err := ParseSchemaObjectIdentifierWithArguments(testCase.RawInput) + + if testCase.Error != "" { + assert.ErrorContains(t, err, testCase.Error) + } else { + assert.NoError(t, err) + assert.Equal(t, testCase.ExpectedIdentifierStructure.FullyQualifiedName(), id.FullyQualifiedName()) + } + }) + } +} diff --git a/pkg/sdk/random_test.go b/pkg/sdk/random_test.go index 8b7c981744..0f58dc3a79 100644 --- a/pkg/sdk/random_test.go +++ b/pkg/sdk/random_test.go @@ -9,11 +9,16 @@ var ( longSchemaObjectIdentifier = NewSchemaObjectIdentifier(random.StringN(255), random.StringN(255), random.StringN(255)) // TODO: Add to the generator - emptyAccountObjectIdentifier = NewAccountObjectIdentifier("") - emptyDatabaseObjectIdentifier = NewDatabaseObjectIdentifier("", "") - emptySchemaObjectIdentifier = NewSchemaObjectIdentifier("", "", "") + emptyAccountObjectIdentifier = NewAccountObjectIdentifier("") + emptyDatabaseObjectIdentifier = NewDatabaseObjectIdentifier("", "") + emptySchemaObjectIdentifier = NewSchemaObjectIdentifier("", "", "") + emptySchemaObjectIdentifierWithArguments = NewSchemaObjectIdentifierWithArguments("", "", "") ) +func randomSchemaObjectIdentifierWithArguments(argumentDataTypes ...DataType) SchemaObjectIdentifierWithArguments { + return NewSchemaObjectIdentifierWithArguments(random.StringN(12), random.StringN(12), random.StringN(12), argumentDataTypes...) +} + func randomSchemaObjectIdentifier() SchemaObjectIdentifier { return NewSchemaObjectIdentifier(random.StringN(12), random.StringN(12), random.StringN(12)) } diff --git a/pkg/sdk/testint/external_functions_integration_test.go b/pkg/sdk/testint/external_functions_integration_test.go index 9787dd7f2e..2901d5e3f7 100644 --- a/pkg/sdk/testint/external_functions_integration_test.go +++ b/pkg/sdk/testint/external_functions_integration_test.go @@ -20,7 +20,7 @@ func TestInt_ExternalFunctions(t *testing.T) { cleanupExternalFunctionHandle := func(id sdk.SchemaObjectIdentifier, dts []sdk.DataType) func() { return func() { - err := client.Functions.Drop(ctx, sdk.NewDropFunctionRequest(id.WithoutArguments(), dts).WithIfExists(sdk.Bool(true))) + err := client.Functions.Drop(ctx, sdk.NewDropFunctionRequest(sdk.NewSchemaObjectIdentifierWithArguments(id.DatabaseName(), id.SchemaName(), id.Name(), dts...)).WithIfExists(true)) require.NoError(t, err) } } @@ -28,7 +28,7 @@ func TestInt_ExternalFunctions(t *testing.T) { // TODO [SNOW-999049]: id returned on purpose; address during identifiers rework createExternalFunction := func(t *testing.T) (*sdk.ExternalFunction, sdk.SchemaObjectIdentifier) { t.Helper() - id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(defaultDataTypes) + id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArgumentsOld(defaultDataTypes...) argument := sdk.NewExternalFunctionArgumentRequest("x", defaultDataTypes[0]) as := "https://xyz.execute-api.us-west-2.amazonaws.com/production/remote_echo" request := sdk.NewCreateExternalFunctionRequest(id, sdk.DataTypeVariant, sdk.Pointer(integration.ID()), as). @@ -77,7 +77,7 @@ func TestInt_ExternalFunctions(t *testing.T) { } t.Run("create external function", func(t *testing.T) { - id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(defaultDataTypes) + id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArgumentsOld(defaultDataTypes...) argument := sdk.NewExternalFunctionArgumentRequest("x", sdk.DataTypeVARCHAR) headers := []sdk.ExternalFunctionHeaderRequest{ { @@ -111,7 +111,7 @@ func TestInt_ExternalFunctions(t *testing.T) { }) t.Run("create external function without arguments", func(t *testing.T) { - id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(nil) + id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArgumentsOld() as := "https://xyz.execute-api.us-west-2.amazonaws.com/production/remote_echo" request := sdk.NewCreateExternalFunctionRequest(id, sdk.DataTypeVariant, sdk.Pointer(integration.ID()), as) err := client.ExternalFunctions.Create(ctx, request) diff --git a/pkg/sdk/testint/functions_integration_test.go b/pkg/sdk/testint/functions_integration_test.go index 205689be40..e3a945934d 100644 --- a/pkg/sdk/testint/functions_integration_test.go +++ b/pkg/sdk/testint/functions_integration_test.go @@ -7,9 +7,10 @@ import ( "testing" "time" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/internal/collections" "github.com/stretchr/testify/assert" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/stretchr/testify/require" ) @@ -23,9 +24,9 @@ func TestInt_CreateFunctions(t *testing.T) { client := testClient(t) ctx := context.Background() - cleanupFunctionHandle := func(id sdk.SchemaObjectIdentifier, dts []sdk.DataType) func() { + cleanupFunctionHandle := func(id sdk.SchemaObjectIdentifierWithArguments) func() { return func() { - err := client.Functions.Drop(ctx, sdk.NewDropFunctionRequest(id, dts)) + err := client.Functions.Drop(ctx, sdk.NewDropFunctionRequest(id)) if errors.Is(err, sdk.ErrObjectNotExistOrAuthorized) { return } @@ -34,8 +35,7 @@ func TestInt_CreateFunctions(t *testing.T) { } t.Run("create function for Java", func(t *testing.T) { - name := "echo_varchar" - id := testClientHelper().Ids.NewSchemaObjectIdentifier(name) + id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeVARCHAR) definition := ` class TestFunc { @@ -45,17 +45,17 @@ func TestInt_CreateFunctions(t *testing.T) { }` target := fmt.Sprintf("@~/tf-%d.jar", time.Now().Unix()) dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) - returns := sdk.NewFunctionReturnsRequest().WithResultDataType(dt) - argument := sdk.NewFunctionArgumentRequest("x", sdk.DataTypeVARCHAR).WithDefaultValue(sdk.String("'abc'")) - request := sdk.NewCreateForJavaFunctionRequest(id, *returns, "TestFunc.echoVarchar"). - WithOrReplace(sdk.Bool(true)). + returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) + argument := sdk.NewFunctionArgumentRequest("x", sdk.DataTypeVARCHAR).WithDefaultValue("'abc'") + request := sdk.NewCreateForJavaFunctionRequest(id.SchemaObjectId(), *returns, "TestFunc.echoVarchar"). + WithOrReplace(true). WithArguments([]sdk.FunctionArgumentRequest{*argument}). - WithNullInputBehavior(sdk.NullInputBehaviorPointer(sdk.NullInputBehaviorCalledOnNullInput)). - WithTargetPath(&target). - WithFunctionDefinition(&definition) + WithNullInputBehavior(*sdk.NullInputBehaviorPointer(sdk.NullInputBehaviorCalledOnNullInput)). + WithTargetPath(target). + WithFunctionDefinition(definition) err := client.Functions.CreateForJava(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupFunctionHandle(id, []sdk.DataType{"VARCHAR"})) + t.Cleanup(cleanupFunctionHandle(id)) function, err := client.Functions.ShowByID(ctx, id) require.NoError(t, err) @@ -64,8 +64,7 @@ func TestInt_CreateFunctions(t *testing.T) { }) t.Run("create function for Javascript", func(t *testing.T) { - name := "js_factorial" - id := testClientHelper().Ids.NewSchemaObjectIdentifier(name) + id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeFloat) definition := ` if (D <= 0) { @@ -79,15 +78,15 @@ func TestInt_CreateFunctions(t *testing.T) { }` dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeFloat) - returns := sdk.NewFunctionReturnsRequest().WithResultDataType(dt) + returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) argument := sdk.NewFunctionArgumentRequest("d", sdk.DataTypeFloat) - request := sdk.NewCreateForJavascriptFunctionRequest(id, *returns, definition). - WithOrReplace(sdk.Bool(true)). + request := sdk.NewCreateForJavascriptFunctionRequest(id.SchemaObjectId(), *returns, definition). + WithOrReplace(true). WithArguments([]sdk.FunctionArgumentRequest{*argument}). - WithNullInputBehavior(sdk.NullInputBehaviorPointer(sdk.NullInputBehaviorCalledOnNullInput)) + WithNullInputBehavior(*sdk.NullInputBehaviorPointer(sdk.NullInputBehaviorCalledOnNullInput)) err := client.Functions.CreateForJavascript(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupFunctionHandle(id, []sdk.DataType{sdk.DataTypeFloat})) + t.Cleanup(cleanupFunctionHandle(id)) function, err := client.Functions.ShowByID(ctx, id) require.NoError(t, err) @@ -96,21 +95,21 @@ func TestInt_CreateFunctions(t *testing.T) { }) t.Run("create function for Python", func(t *testing.T) { - id := testClientHelper().Ids.RandomSchemaObjectIdentifier() + id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeNumber) definition := ` def dump(i): print("Hello World!")` dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeVariant) - returns := sdk.NewFunctionReturnsRequest().WithResultDataType(dt) + returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) argument := sdk.NewFunctionArgumentRequest("i", sdk.DataTypeNumber) - request := sdk.NewCreateForPythonFunctionRequest(id, *returns, "3.8", "dump"). - WithOrReplace(sdk.Bool(true)). + request := sdk.NewCreateForPythonFunctionRequest(id.SchemaObjectId(), *returns, "3.8", "dump"). + WithOrReplace(true). WithArguments([]sdk.FunctionArgumentRequest{*argument}). - WithFunctionDefinition(&definition) + WithFunctionDefinition(definition) err := client.Functions.CreateForPython(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupFunctionHandle(id, []sdk.DataType{"int"})) + t.Cleanup(cleanupFunctionHandle(id)) function, err := client.Functions.ShowByID(ctx, id) require.NoError(t, err) @@ -119,8 +118,7 @@ def dump(i): }) t.Run("create function for Scala", func(t *testing.T) { - name := "echo_varchar" - id := testClientHelper().Ids.NewSchemaObjectIdentifier(name) + id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeVARCHAR) definition := ` class Echo { @@ -130,14 +128,14 @@ def dump(i): }` argument := sdk.NewFunctionArgumentRequest("x", sdk.DataTypeVARCHAR) - request := sdk.NewCreateForScalaFunctionRequest(id, sdk.DataTypeVARCHAR, "Echo.echoVarchar"). - WithOrReplace(sdk.Bool(true)). + request := sdk.NewCreateForScalaFunctionRequest(id.SchemaObjectId(), sdk.DataTypeVARCHAR, "Echo.echoVarchar"). + WithOrReplace(true). WithArguments([]sdk.FunctionArgumentRequest{*argument}). - WithRuntimeVersion(sdk.String("2.12")). - WithFunctionDefinition(&definition) + WithRuntimeVersion("2.12"). + WithFunctionDefinition(definition) err := client.Functions.CreateForScala(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupFunctionHandle(id, []sdk.DataType{sdk.DataTypeVARCHAR})) + t.Cleanup(cleanupFunctionHandle(id)) function, err := client.Functions.ShowByID(ctx, id) require.NoError(t, err) @@ -146,20 +144,20 @@ def dump(i): }) t.Run("create function for SQL", func(t *testing.T) { - id := testClientHelper().Ids.RandomSchemaObjectIdentifier() + id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeFloat) definition := "3.141592654::FLOAT" dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeFloat) - returns := sdk.NewFunctionReturnsRequest().WithResultDataType(dt) + returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) argument := sdk.NewFunctionArgumentRequest("x", sdk.DataTypeFloat) - request := sdk.NewCreateForSQLFunctionRequest(id, *returns, definition). + request := sdk.NewCreateForSQLFunctionRequest(id.SchemaObjectId(), *returns, definition). WithArguments([]sdk.FunctionArgumentRequest{*argument}). - WithOrReplace(sdk.Bool(true)). - WithComment(sdk.String("comment")) + WithOrReplace(true). + WithComment("comment") err := client.Functions.CreateForSQL(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupFunctionHandle(id, []sdk.DataType{sdk.DataTypeFloat})) + t.Cleanup(cleanupFunctionHandle(id)) function, err := client.Functions.ShowByID(ctx, id) require.NoError(t, err) @@ -168,18 +166,18 @@ def dump(i): }) t.Run("create function for SQL with no arguments", func(t *testing.T) { - id := testClientHelper().Ids.RandomSchemaObjectIdentifier() + id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments() definition := "3.141592654::FLOAT" dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeFloat) - returns := sdk.NewFunctionReturnsRequest().WithResultDataType(dt) - request := sdk.NewCreateForSQLFunctionRequest(id, *returns, definition). - WithOrReplace(sdk.Bool(true)). - WithComment(sdk.String("comment")) + returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) + request := sdk.NewCreateForSQLFunctionRequest(id.SchemaObjectId(), *returns, definition). + WithOrReplace(true). + WithComment("comment") err := client.Functions.CreateForSQL(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupFunctionHandle(id, nil)) + t.Cleanup(cleanupFunctionHandle(id)) function, err := client.Functions.ShowByID(ctx, id) require.NoError(t, err) @@ -195,7 +193,7 @@ func TestInt_OtherFunctions(t *testing.T) { tagTest, tagCleanup := testClientHelper().Tag.CreateTag(t) t.Cleanup(tagCleanup) - assertFunction := func(t *testing.T, id sdk.SchemaObjectIdentifier, secure bool, withArguments bool) { + assertFunction := func(t *testing.T, id sdk.SchemaObjectIdentifierWithArguments, secure bool, withArguments bool) { t.Helper() function, err := client.Functions.ShowByID(ctx, id) @@ -213,6 +211,7 @@ func TestInt_OtherFunctions(t *testing.T) { assert.Equal(t, 0, function.MinNumArguments) assert.Equal(t, 0, function.MaxNumArguments) } + assert.NotEmpty(t, function.ArgumentsRaw) assert.NotEmpty(t, function.Arguments) assert.NotEmpty(t, function.Description) assert.NotEmpty(t, function.CatalogName) @@ -224,9 +223,9 @@ func TestInt_OtherFunctions(t *testing.T) { assert.Equal(t, false, function.IsMemoizable) } - cleanupFunctionHandle := func(id sdk.SchemaObjectIdentifier, dts []sdk.DataType) func() { + cleanupFunctionHandle := func(id sdk.SchemaObjectIdentifierWithArguments) func() { return func() { - err := client.Functions.Drop(ctx, sdk.NewDropFunctionRequest(id, dts)) + err := client.Functions.Drop(ctx, sdk.NewDropFunctionRequest(id)) if errors.Is(err, sdk.ErrObjectNotExistOrAuthorized) { return } @@ -236,14 +235,19 @@ func TestInt_OtherFunctions(t *testing.T) { createFunctionForSQLHandle := func(t *testing.T, cleanup bool, withArguments bool) *sdk.Function { t.Helper() - id := testClientHelper().Ids.RandomSchemaObjectIdentifier() + var id sdk.SchemaObjectIdentifierWithArguments + if withArguments { + id = testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeFloat) + } else { + id = testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments() + } definition := "3.141592654::FLOAT" dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeFloat) - returns := sdk.NewFunctionReturnsRequest().WithResultDataType(dt) - request := sdk.NewCreateForSQLFunctionRequest(id, *returns, definition). - WithOrReplace(sdk.Bool(true)) + returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) + request := sdk.NewCreateForSQLFunctionRequest(id.SchemaObjectId(), *returns, definition). + WithOrReplace(true) if withArguments { argument := sdk.NewFunctionArgumentRequest("x", sdk.DataTypeFloat) request = request.WithArguments([]sdk.FunctionArgumentRequest{*argument}) @@ -251,31 +255,23 @@ func TestInt_OtherFunctions(t *testing.T) { err := client.Functions.CreateForSQL(ctx, request) require.NoError(t, err) if cleanup { - if withArguments { - t.Cleanup(cleanupFunctionHandle(id, []sdk.DataType{sdk.DataTypeFloat})) - } else { - t.Cleanup(cleanupFunctionHandle(id, nil)) - } + t.Cleanup(cleanupFunctionHandle(id)) } function, err := client.Functions.ShowByID(ctx, id) require.NoError(t, err) return function } - defaultAlterRequest := func(id sdk.SchemaObjectIdentifier) *sdk.AlterFunctionRequest { - return sdk.NewAlterFunctionRequest(id, []sdk.DataType{sdk.DataTypeFloat}) - } - t.Run("alter function: rename", func(t *testing.T) { f := createFunctionForSQLHandle(t, false, true) id := f.ID() - nid := testClientHelper().Ids.RandomSchemaObjectIdentifier() - err := client.Functions.Alter(ctx, defaultAlterRequest(id).WithRenameTo(&nid)) + nid := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments() + err := client.Functions.Alter(ctx, sdk.NewAlterFunctionRequest(id).WithRenameTo(nid.SchemaObjectId())) if err != nil { - t.Cleanup(cleanupFunctionHandle(id, []sdk.DataType{sdk.DataTypeFloat})) + t.Cleanup(cleanupFunctionHandle(id)) } else { - t.Cleanup(cleanupFunctionHandle(nid, []sdk.DataType{sdk.DataTypeFloat})) + t.Cleanup(cleanupFunctionHandle(nid)) } require.NoError(t, err) @@ -291,7 +287,7 @@ func TestInt_OtherFunctions(t *testing.T) { f := createFunctionForSQLHandle(t, true, true) id := f.ID() - err := client.Functions.Alter(ctx, defaultAlterRequest(id).WithSetLogLevel(sdk.String("DEBUG"))) + err := client.Functions.Alter(ctx, sdk.NewAlterFunctionRequest(id).WithSetLogLevel(string(sdk.LogLevelDebug))) require.NoError(t, err) assertFunction(t, id, false, true) }) @@ -300,7 +296,7 @@ func TestInt_OtherFunctions(t *testing.T) { f := createFunctionForSQLHandle(t, true, true) id := f.ID() - err := client.Functions.Alter(ctx, defaultAlterRequest(id).WithUnsetLogLevel(sdk.Bool(true))) + err := client.Functions.Alter(ctx, sdk.NewAlterFunctionRequest(id).WithUnsetLogLevel(true)) require.NoError(t, err) assertFunction(t, id, false, true) }) @@ -309,7 +305,7 @@ func TestInt_OtherFunctions(t *testing.T) { f := createFunctionForSQLHandle(t, true, true) id := f.ID() - err := client.Functions.Alter(ctx, defaultAlterRequest(id).WithSetTraceLevel(sdk.String("ALWAYS"))) + err := client.Functions.Alter(ctx, sdk.NewAlterFunctionRequest(id).WithSetTraceLevel(string(sdk.TraceLevelAlways))) require.NoError(t, err) assertFunction(t, id, false, true) }) @@ -318,7 +314,7 @@ func TestInt_OtherFunctions(t *testing.T) { f := createFunctionForSQLHandle(t, true, true) id := f.ID() - err := client.Functions.Alter(ctx, defaultAlterRequest(id).WithUnsetTraceLevel(sdk.Bool(true))) + err := client.Functions.Alter(ctx, sdk.NewAlterFunctionRequest(id).WithUnsetTraceLevel(true)) require.NoError(t, err) assertFunction(t, id, false, true) }) @@ -327,7 +323,7 @@ func TestInt_OtherFunctions(t *testing.T) { f := createFunctionForSQLHandle(t, true, true) id := f.ID() - err := client.Functions.Alter(ctx, defaultAlterRequest(id).WithSetComment(sdk.String("test comment"))) + err := client.Functions.Alter(ctx, sdk.NewAlterFunctionRequest(id).WithSetComment("test comment")) require.NoError(t, err) assertFunction(t, id, false, true) }) @@ -336,7 +332,7 @@ func TestInt_OtherFunctions(t *testing.T) { f := createFunctionForSQLHandle(t, true, true) id := f.ID() - err := client.Functions.Alter(ctx, defaultAlterRequest(id).WithUnsetComment(sdk.Bool(true))) + err := client.Functions.Alter(ctx, sdk.NewAlterFunctionRequest(id).WithUnsetComment(true)) require.NoError(t, err) assertFunction(t, id, false, true) }) @@ -345,24 +341,24 @@ func TestInt_OtherFunctions(t *testing.T) { f := createFunctionForSQLHandle(t, true, true) id := f.ID() - err := client.Functions.Alter(ctx, defaultAlterRequest(id).WithSetSecure(sdk.Bool(true))) + err := client.Functions.Alter(ctx, sdk.NewAlterFunctionRequest(id).WithSetSecure(true)) require.NoError(t, err) assertFunction(t, id, true, true) }) t.Run("alter function: set secure with no arguments", func(t *testing.T) { - f := createFunctionForSQLHandle(t, true, false) + f := createFunctionForSQLHandle(t, true, true) id := f.ID() - err := client.Functions.Alter(ctx, sdk.NewAlterFunctionRequest(id, nil).WithSetSecure(sdk.Bool(true))) + err := client.Functions.Alter(ctx, sdk.NewAlterFunctionRequest(id).WithSetSecure(true)) require.NoError(t, err) - assertFunction(t, id, true, false) + assertFunction(t, id, true, true) }) t.Run("alter function: unset secure", func(t *testing.T) { f := createFunctionForSQLHandle(t, true, true) id := f.ID() - err := client.Functions.Alter(ctx, defaultAlterRequest(id).WithUnsetSecure(sdk.Bool(true))) + err := client.Functions.Alter(ctx, sdk.NewAlterFunctionRequest(id).WithUnsetSecure(true)) require.NoError(t, err) assertFunction(t, id, false, true) }) @@ -377,14 +373,14 @@ func TestInt_OtherFunctions(t *testing.T) { Value: "v1", }, } - err := client.Functions.Alter(ctx, defaultAlterRequest(id).WithSetTags(setTags)) + err := client.Functions.Alter(ctx, sdk.NewAlterFunctionRequest(id).WithSetTags(setTags)) require.NoError(t, err) assertFunction(t, id, false, true) unsetTags := []sdk.ObjectIdentifier{ tagTest.ID(), } - err = client.Functions.Alter(ctx, defaultAlterRequest(id).WithUnsetTags(unsetTags)) + err = client.Functions.Alter(ctx, sdk.NewAlterFunctionRequest(id).WithUnsetTags(unsetTags)) require.NoError(t, err) assertFunction(t, id, false, true) }) @@ -404,7 +400,7 @@ func TestInt_OtherFunctions(t *testing.T) { f1 := createFunctionForSQLHandle(t, true, true) f2 := createFunctionForSQLHandle(t, true, true) - functions, err := client.Functions.Show(ctx, sdk.NewShowFunctionRequest().WithLike(&sdk.Like{Pattern: &f1.Name})) + functions, err := client.Functions.Show(ctx, sdk.NewShowFunctionRequest().WithLike(sdk.Like{Pattern: &f1.Name})) require.NoError(t, err) require.Equal(t, 1, len(functions)) @@ -413,17 +409,15 @@ func TestInt_OtherFunctions(t *testing.T) { }) t.Run("show function for SQL: no matches", func(t *testing.T) { - functions, err := client.Functions.Show(ctx, sdk.NewShowFunctionRequest().WithLike(&sdk.Like{Pattern: sdk.String("non-existing-id-pattern")})) + functions, err := client.Functions.Show(ctx, sdk.NewShowFunctionRequest().WithLike(sdk.Like{Pattern: sdk.String("non-existing-id-pattern")})) require.NoError(t, err) require.Equal(t, 0, len(functions)) }) t.Run("describe function for SQL", func(t *testing.T) { f := createFunctionForSQLHandle(t, true, true) - id := f.ID() - request := sdk.NewDescribeFunctionRequest(id, []sdk.DataType{sdk.DataTypeFloat}) - details, err := client.Functions.Describe(ctx, request) + details, err := client.Functions.Describe(ctx, f.ID()) require.NoError(t, err) pairs := make(map[string]string) for _, detail := range details { @@ -437,10 +431,8 @@ func TestInt_OtherFunctions(t *testing.T) { t.Run("describe function for SQL: no arguments", func(t *testing.T) { f := createFunctionForSQLHandle(t, true, false) - id := f.ID() - request := sdk.NewDescribeFunctionRequest(id, nil) - details, err := client.Functions.Describe(ctx, request) + details, err := client.Functions.Describe(ctx, f.ID()) require.NoError(t, err) pairs := make(map[string]string) for _, detail := range details { @@ -457,9 +449,9 @@ func TestInt_FunctionsShowByID(t *testing.T) { client := testClient(t) ctx := testContext(t) - cleanupFunctionHandle := func(id sdk.SchemaObjectIdentifier, dts []sdk.DataType) func() { + cleanupFunctionHandle := func(id sdk.SchemaObjectIdentifierWithArguments) func() { return func() { - err := client.Functions.Drop(ctx, sdk.NewDropFunctionRequest(id, dts)) + err := client.Functions.Drop(ctx, sdk.NewDropFunctionRequest(id)) if errors.Is(err, sdk.ErrObjectNotExistOrAuthorized) { return } @@ -467,37 +459,99 @@ func TestInt_FunctionsShowByID(t *testing.T) { } } - createFunctionForSQLHandle := func(t *testing.T, id sdk.SchemaObjectIdentifier) { + createFunctionForSQLHandle := func(t *testing.T, id sdk.SchemaObjectIdentifierWithArguments) { t.Helper() definition := "3.141592654::FLOAT" dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeFloat) - returns := sdk.NewFunctionReturnsRequest().WithResultDataType(dt) - request := sdk.NewCreateForSQLFunctionRequest(id, *returns, definition).WithOrReplace(sdk.Bool(true)) + returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) + request := sdk.NewCreateForSQLFunctionRequest(id.SchemaObjectId(), *returns, definition).WithOrReplace(true) argument := sdk.NewFunctionArgumentRequest("x", sdk.DataTypeFloat) request = request.WithArguments([]sdk.FunctionArgumentRequest{*argument}) err := client.Functions.CreateForSQL(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupFunctionHandle(id, []sdk.DataType{sdk.DataTypeFloat})) + t.Cleanup(cleanupFunctionHandle(id)) } t.Run("show by id - same name in different schemas", func(t *testing.T) { schema, schemaCleanup := testClientHelper().Schema.CreateSchema(t) t.Cleanup(schemaCleanup) - id1 := testClientHelper().Ids.RandomSchemaObjectIdentifier() - id2 := testClientHelper().Ids.NewSchemaObjectIdentifierInSchema(id1.Name(), schema.ID()) + id1 := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeFloat) + id2 := testClientHelper().Ids.NewSchemaObjectIdentifierWithArgumentsInSchema(id1.Name(), schema.ID(), sdk.DataTypeFloat) createFunctionForSQLHandle(t, id1) createFunctionForSQLHandle(t, id2) e1, err := client.Functions.ShowByID(ctx, id1) require.NoError(t, err) - require.Equal(t, id1, e1.ID()) + + e1Id := e1.ID() + require.NoError(t, err) + require.Equal(t, id1, e1Id) e2, err := client.Functions.ShowByID(ctx, id2) require.NoError(t, err) - require.Equal(t, id2, e2.ID()) + + e2Id := e2.ID() + require.NoError(t, err) + require.Equal(t, id2, e2Id) + }) + + t.Run("function returns non detailed data types of arguments", func(t *testing.T) { + // This test proves that every detailed data types (e.g. VARCHAR(20) and NUMBER(10, 0)) are generalized + // on Snowflake side (to e.g. VARCHAR and NUMBER) and that sdk.ToDataType mapping function maps detailed types + // correctly to their generalized counterparts (same as in Snowflake). + + id := testClientHelper().Ids.RandomSchemaObjectIdentifier() + args := []sdk.FunctionArgumentRequest{ + *sdk.NewFunctionArgumentRequest("A", "NUMBER(2, 0)"), + *sdk.NewFunctionArgumentRequest("B", "DECIMAL"), + *sdk.NewFunctionArgumentRequest("C", "INTEGER"), + *sdk.NewFunctionArgumentRequest("D", sdk.DataTypeFloat), + *sdk.NewFunctionArgumentRequest("E", "DOUBLE"), + *sdk.NewFunctionArgumentRequest("F", "VARCHAR(20)"), + *sdk.NewFunctionArgumentRequest("G", "CHAR"), + *sdk.NewFunctionArgumentRequest("H", sdk.DataTypeString), + *sdk.NewFunctionArgumentRequest("I", "TEXT"), + *sdk.NewFunctionArgumentRequest("J", sdk.DataTypeBinary), + *sdk.NewFunctionArgumentRequest("K", "VARBINARY"), + *sdk.NewFunctionArgumentRequest("L", sdk.DataTypeBoolean), + *sdk.NewFunctionArgumentRequest("M", sdk.DataTypeDate), + *sdk.NewFunctionArgumentRequest("N", "DATETIME"), + *sdk.NewFunctionArgumentRequest("O", sdk.DataTypeTime), + *sdk.NewFunctionArgumentRequest("P", sdk.DataTypeTimestamp), + *sdk.NewFunctionArgumentRequest("R", sdk.DataTypeTimestampLTZ), + *sdk.NewFunctionArgumentRequest("S", sdk.DataTypeTimestampNTZ), + *sdk.NewFunctionArgumentRequest("T", sdk.DataTypeTimestampTZ), + *sdk.NewFunctionArgumentRequest("U", sdk.DataTypeVariant), + *sdk.NewFunctionArgumentRequest("V", sdk.DataTypeObject), + *sdk.NewFunctionArgumentRequest("W", sdk.DataTypeArray), + *sdk.NewFunctionArgumentRequest("X", sdk.DataTypeGeography), + *sdk.NewFunctionArgumentRequest("Y", sdk.DataTypeGeometry), + *sdk.NewFunctionArgumentRequest("Z", "VECTOR(INT, 16)"), + } + err := client.Functions.CreateForPython(ctx, sdk.NewCreateForPythonFunctionRequest( + id, + *sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeVariant)), + "3.8", + "add", + ). + WithArguments(args). + WithFunctionDefinition("def add(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, R, S, T, U, V, W, X, Y, Z): A + A"), + ) + require.NoError(t, err) + + dataTypes := make([]sdk.DataType, len(args)) + for i, arg := range args { + dataTypes[i], err = sdk.ToDataType(string(arg.ArgDataType)) + require.NoError(t, err) + } + idWithArguments := sdk.NewSchemaObjectIdentifierWithArguments(id.DatabaseName(), id.SchemaName(), id.Name(), dataTypes...) + + function, err := client.Functions.ShowByID(ctx, idWithArguments) + require.NoError(t, err) + require.Equal(t, dataTypes, function.Arguments) }) }