Skip to content

Commit

Permalink
feat: integer return type for procedure (#1266)
Browse files Browse the repository at this point in the history
* integer return type for procedure

* integer return type for procedure

* integer return type for procedure

* integer return type for procedure

* integer return type for procedure

* update test

* update test

* update test
  • Loading branch information
sfc-gh-swinkler authored Oct 11, 2022
1 parent aede29f commit c1cf881
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 20 deletions.
50 changes: 34 additions & 16 deletions pkg/resources/procedure.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ import (
"database/sql"
"fmt"
"log"
"regexp"
"strings"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
"github.com/pkg/errors"
"golang.org/x/exp/slices"
)

var procedureLanguages = []string{"javascript", "java", "scala", "SQL"}
Expand Down Expand Up @@ -41,15 +41,19 @@ var procedureSchema = map[string]*schema.Schema{
Type: schema.TypeString,
Required: true,
// Suppress the diff shown if the values are equal when both compared in lower case.
DiffSuppressFunc: DiffTypes,
Description: "The argument name",
DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool {
return strings.EqualFold(old, new)
},
Description: "The argument name",
},
"type": {
Type: schema.TypeString,
Required: true,
// Suppress the diff shown if the values are equal when both compared in lower case.
DiffSuppressFunc: DiffTypes,
Description: "The argument type",
DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool {
return strings.EqualFold(old, new)
},
Description: "The argument type",
},
},
},
Expand All @@ -61,9 +65,25 @@ var procedureSchema = map[string]*schema.Schema{
Type: schema.TypeString,
Description: "The return type of the procedure",
// Suppress the diff shown if the values are equal when both compared in lower case.
DiffSuppressFunc: DiffTypes,
Required: true,
ForceNew: true,
DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool {
if strings.EqualFold(old, new) {
return true
}

varcharType := []string{"VARCHAR(16777216)", "VARCHAR", "text", "string", "NVARCHAR", "NVARCHAR2", "CHAR VARYING", "NCHAR VARYING"}
if slices.Contains(varcharType, strings.ToUpper(old)) && slices.Contains(varcharType, strings.ToUpper(new)) {
return true
}

// all these types are equivalent https://docs.snowflake.com/en/sql-reference/data-types-numeric.html#int-integer-bigint-smallint-tinyint-byteint
integerTypes := []string{"INT", "INTEGER", "BIGINT", "SMALLINT", "TINYINT", "BYTEINT", "NUMBER(38,0)"}
if slices.Contains(integerTypes, strings.ToUpper(old)) && slices.Contains(integerTypes, strings.ToUpper(new)) {
return true
}
return false
},
Required: true,
ForceNew: true,
},
"statement": {
Type: schema.TypeString,
Expand All @@ -76,10 +96,11 @@ var procedureSchema = map[string]*schema.Schema{
Type: schema.TypeString,
Optional: true,
Default: "SQL",
// Suppress the diff shown if the values are equal when both compared in lower case.
DiffSuppressFunc: DiffTypes,
ValidateFunc: validation.StringInSlice(procedureLanguages, true),
Description: "Specifies the language of the stored procedure code.",
DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool {
return strings.EqualFold(old, new)
},
ValidateFunc: validation.StringInSlice(procedureLanguages, true),
Description: "Specifies the language of the stored procedure code.",
},
"execute_as": {
Type: schema.TypeString,
Expand Down Expand Up @@ -268,10 +289,7 @@ func ReadProcedure(d *schema.ResourceData, meta interface{}) error {
return err
}
case "returns":
// Format in Snowflake DB is RETURN_TYPE(<some number>) or RETURN_TYPE
re := regexp.MustCompile(`^([A-Z0-9_]+)\s?(\([0-9]*\))?$`)
match := re.FindStringSubmatch(desc.Value.String)
if err = d.Set("return_type", match[1]); err != nil {
if err = d.Set("return_type", desc.Value.String); err != nil {
return err
}
case "language":
Expand Down
25 changes: 24 additions & 1 deletion pkg/resources/procedure_acceptance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ func TestAcc_Procedure(t *testing.T) {
resource.TestCheckResourceAttr("snowflake_procedure.test_proc_complex", "arguments.1.type", "DATE"),
resource.TestCheckResourceAttr("snowflake_procedure.test_proc_complex", "return_behavior", "IMMUTABLE"),
resource.TestCheckResourceAttr("snowflake_procedure.test_proc_complex", "null_input_behavior", "RETURNS NULL ON NULL INPUT"),

resource.TestCheckResourceAttr("snowflake_procedure.test_proc_sql", "name", procName+"_sql"),
),
},
{
Expand Down Expand Up @@ -125,5 +127,26 @@ func procedureConfig(db, schema, name string) string {
return X
EOF
}
`, db, schema, name, name, name)
resource "snowflake_procedure" "test_proc_sql" {
name = "%s_sql"
database = snowflake_database.test_database.name
schema = snowflake_schema.test_schema.name
language = "SQL"
return_type = "INTEGER"
execute_as = "CALLER"
return_behavior = "IMMUTABLE"
null_input_behavior = "RETURNS NULL ON NULL INPUT"
statement = <<EOT
declare
x integer;
y integer;
begin
x := 3;
y := x * x;
return y;
end;
EOT
}
`, db, schema, name, name, name, name)
}
3 changes: 0 additions & 3 deletions pkg/resources/procedure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ func TestProcedureCreate(t *testing.T) {
err := resources.CreateProcedure(d, db)
r.NoError(err)
r.Equal("MY_PROC", d.Get("name").(string))
r.Equal("VARCHAR", d.Get("return_type").(string))
r.Equal("mock comment", d.Get("comment").(string))
})
}
Expand Down Expand Up @@ -89,7 +88,6 @@ func TestProcedureRead(t *testing.T) {
r.Equal("MY_DB", d.Get("database").(string))
r.Equal("MY_SCHEMA", d.Get("schema").(string))
r.Equal("mock comment", d.Get("comment").(string))
r.Equal("VARCHAR", d.Get("return_type").(string))
r.Equal("SQL", d.Get("language").(string))
r.Equal("IMMUTABLE", d.Get("return_behavior").(string))
r.Equal(procedureBody, d.Get("statement").(string))
Expand All @@ -112,7 +110,6 @@ func TestProcedureRead(t *testing.T) {
err := resources.ReadProcedure(d, db)
r.NoError(err)
r.Equal("MY_PROC", d.Get("name").(string))
r.Equal("TABLE", d.Get("return_type").(string))
})
}

Expand Down

0 comments on commit c1cf881

Please sign in to comment.