diff --git a/pkg/resources/procedure.go b/pkg/resources/procedure.go index 5476bb6c4b..85d0fe56ce 100644 --- a/pkg/resources/procedure.go +++ b/pkg/resources/procedure.go @@ -4,13 +4,13 @@ import ( "database/sql" "fmt" "log" - "regexp" "strings" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" "github.com/pkg/errors" + "golang.org/x/exp/slices" ) var procedureLanguages = []string{"javascript", "java", "scala", "SQL"} @@ -41,15 +41,19 @@ var procedureSchema = map[string]*schema.Schema{ Type: schema.TypeString, Required: true, // Suppress the diff shown if the values are equal when both compared in lower case. - DiffSuppressFunc: DiffTypes, - Description: "The argument name", + DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { + return strings.EqualFold(old, new) + }, + Description: "The argument name", }, "type": { Type: schema.TypeString, Required: true, // Suppress the diff shown if the values are equal when both compared in lower case. - DiffSuppressFunc: DiffTypes, - Description: "The argument type", + DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { + return strings.EqualFold(old, new) + }, + Description: "The argument type", }, }, }, @@ -61,9 +65,25 @@ var procedureSchema = map[string]*schema.Schema{ Type: schema.TypeString, Description: "The return type of the procedure", // Suppress the diff shown if the values are equal when both compared in lower case. - DiffSuppressFunc: DiffTypes, - Required: true, - ForceNew: true, + DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { + if strings.EqualFold(old, new) { + return true + } + + varcharType := []string{"VARCHAR(16777216)", "VARCHAR", "text", "string", "NVARCHAR", "NVARCHAR2", "CHAR VARYING", "NCHAR VARYING"} + if slices.Contains(varcharType, strings.ToUpper(old)) && slices.Contains(varcharType, strings.ToUpper(new)) { + return true + } + + // all these types are equivalent https://docs.snowflake.com/en/sql-reference/data-types-numeric.html#int-integer-bigint-smallint-tinyint-byteint + integerTypes := []string{"INT", "INTEGER", "BIGINT", "SMALLINT", "TINYINT", "BYTEINT", "NUMBER(38,0)"} + if slices.Contains(integerTypes, strings.ToUpper(old)) && slices.Contains(integerTypes, strings.ToUpper(new)) { + return true + } + return false + }, + Required: true, + ForceNew: true, }, "statement": { Type: schema.TypeString, @@ -76,10 +96,11 @@ var procedureSchema = map[string]*schema.Schema{ Type: schema.TypeString, Optional: true, Default: "SQL", - // Suppress the diff shown if the values are equal when both compared in lower case. - DiffSuppressFunc: DiffTypes, - ValidateFunc: validation.StringInSlice(procedureLanguages, true), - Description: "Specifies the language of the stored procedure code.", + DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { + return strings.EqualFold(old, new) + }, + ValidateFunc: validation.StringInSlice(procedureLanguages, true), + Description: "Specifies the language of the stored procedure code.", }, "execute_as": { Type: schema.TypeString, @@ -268,10 +289,7 @@ func ReadProcedure(d *schema.ResourceData, meta interface{}) error { return err } case "returns": - // Format in Snowflake DB is RETURN_TYPE() or RETURN_TYPE - re := regexp.MustCompile(`^([A-Z0-9_]+)\s?(\([0-9]*\))?$`) - match := re.FindStringSubmatch(desc.Value.String) - if err = d.Set("return_type", match[1]); err != nil { + if err = d.Set("return_type", desc.Value.String); err != nil { return err } case "language": diff --git a/pkg/resources/procedure_acceptance_test.go b/pkg/resources/procedure_acceptance_test.go index 4675b89731..fea54ce661 100644 --- a/pkg/resources/procedure_acceptance_test.go +++ b/pkg/resources/procedure_acceptance_test.go @@ -50,6 +50,8 @@ func TestAcc_Procedure(t *testing.T) { resource.TestCheckResourceAttr("snowflake_procedure.test_proc_complex", "arguments.1.type", "DATE"), resource.TestCheckResourceAttr("snowflake_procedure.test_proc_complex", "return_behavior", "IMMUTABLE"), resource.TestCheckResourceAttr("snowflake_procedure.test_proc_complex", "null_input_behavior", "RETURNS NULL ON NULL INPUT"), + + resource.TestCheckResourceAttr("snowflake_procedure.test_proc_sql", "name", procName+"_sql"), ), }, { @@ -125,5 +127,26 @@ func procedureConfig(db, schema, name string) string { return X EOF } - `, db, schema, name, name, name) + + resource "snowflake_procedure" "test_proc_sql" { + name = "%s_sql" + database = snowflake_database.test_database.name + schema = snowflake_schema.test_schema.name + language = "SQL" + return_type = "INTEGER" + execute_as = "CALLER" + return_behavior = "IMMUTABLE" + null_input_behavior = "RETURNS NULL ON NULL INPUT" + statement = <