Skip to content

Commit

Permalink
feat: add custom oauth int (#1286)
Browse files Browse the repository at this point in the history
* add custom oauth int

* add custom oauth int

* add custom oauth int

* add custom oauth int

* add custom oauth int

* add custom oauth int

* add custom oauth int

* add custom oauth int

* add custom oauth int

* add custom oauth int

* add custom oauth int

* add custom oauth int

* add custom oauth int

* add custom oauth int
  • Loading branch information
sfc-gh-swinkler authored Oct 20, 2022
1 parent 132373c commit d6397f9
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/resources/oauth_integration.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ resource "snowflake_oauth_integration" "tableau_desktop" {
- `comment` (String) Specifies a comment for the OAuth integration.
- `enabled` (Boolean) Specifies whether this OAuth integration is enabled or disabled.
- `oauth_issue_refresh_tokens` (Boolean) Specifies whether to allow the client to exchange a refresh token for an access token when the current access token has expired.
- `oauth_redirect_uri` (String) Specifies the client URI. After a user is authenticated, the web browser is redirected to this URI.
- `oauth_refresh_token_validity` (Number) Specifies how long refresh tokens should be valid (in seconds). OAUTH_ISSUE_REFRESH_TOKENS must be set to TRUE.
- `oauth_use_secondary_roles` (String) Specifies whether default secondary roles set in the user properties are activated by default in the session being opened.

Expand Down
21 changes: 19 additions & 2 deletions pkg/resources/oauth_integration.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,14 @@ var oauthIntegrationSchema = map[string]*schema.Schema{
Required: true,
Description: "Specifies the OAuth client type.",
ValidateFunc: validation.StringInSlice([]string{
"TABLEAU_DESKTOP", "TABLEAU_SERVER", "LOOKER",
"TABLEAU_DESKTOP", "TABLEAU_SERVER", "LOOKER", "CUSTOM",
}, false),
},
"oauth_redirect_uri": {
Type: schema.TypeString,
Optional: true,
Description: "Specifies the client URI. After a user is authenticated, the web browser is redirected to this URI.",
},
"oauth_issue_refresh_tokens": {
Type: schema.TypeBool,
Optional: true,
Expand Down Expand Up @@ -95,8 +100,11 @@ func CreateOAuthIntegration(d *schema.ResourceData, meta interface{}) error {
// Set required fields
stmt.SetRaw(`TYPE=OAUTH`)
stmt.SetString(`OAUTH_CLIENT`, d.Get("oauth_client").(string))

// Set optional fields
if _, ok := d.GetOk("oauth_redirect_uri"); ok {
stmt.SetString(`OAUTH_REDIRECT_URI`, d.Get("oauth_redirect_uri").(string))
}

if _, ok := d.GetOk("oauth_issue_refresh_tokens"); ok {
stmt.SetBool(`OAUTH_ISSUE_REFRESH_TOKENS`, d.Get("oauth_issue_refresh_tokens").(bool))
}
Expand Down Expand Up @@ -220,6 +228,10 @@ func ReadOAuthIntegration(d *schema.ResourceData, meta interface{}) error {
if err = d.Set("blocked_roles_list", blockedRolesCustom); err != nil {
return errors.Wrap(err, "unable to set blocked roles list for security integration")
}
case "OAUTH_REDIRECT_URI":
if err = d.Set("oauth_redirect_uri", v.(string)); err != nil {
return errors.Wrap(err, "unable to set OAuth redirect URI for security integration")
}
case "OAUTH_CLIENT_TYPE":
// Only used for custom OAuth clients (not supported yet)
case "OAUTH_ENFORCE_PKCE":
Expand Down Expand Up @@ -257,6 +269,11 @@ func UpdateOAuthIntegration(d *schema.ResourceData, meta interface{}) error {
stmt.SetString(`OAUTH_CLIENT`, d.Get("oauth_client").(string))
}

if d.HasChange("oauth_redirect_uri") {
runSetStatement = true
stmt.SetString(`OAUTH_REDIRECT_URI`, d.Get("oauth_redirect_uri").(string))
}

if d.HasChange("oauth_issue_refresh_tokens") {
runSetStatement = true
stmt.SetBool(`OAUTH_ISSUE_REFRESH_TOKENS`, d.Get("oauth_issue_refresh_tokens").(bool))
Expand Down
27 changes: 27 additions & 0 deletions pkg/resources/snowflake_sweeper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,38 @@ func getUsersSweeper(name string) *resource.Sweeper {
}
}

func getIntegrationsSweeper(name string) *resource.Sweeper {
return &resource.Sweeper{
Name: name,
F: func(ununsed string) error {
db, err := provider.GetDatabaseHandleFromEnv()
if err != nil {
return fmt.Errorf("Error getting db handle: %w", err)
}
integrations, err := snowflake.ListIntegrations(db)
if err != nil {
return fmt.Errorf("Error listing integrations: %w", err)
}
for _, integration := range integrations {
// can only drop security integrations
if integration.IntegrationType.String == "SECURITY" {
err := snowflake.DropIntegration(db, integration.Name.String)
if err != nil {
return fmt.Errorf("Error deleting integration %q %w", integration.Name.String, err)
}
}
}
return nil
},
}
}

// Sweepers usually go along with the tests. In TF[CE]'s case everything depends on the organization,
// which means that if we delete it then all the other entities will be deleted automatically.
func init() {
resource.AddTestSweepers("wh_sweeper", getWarehousesSweeper("wh_sweeper"))
resource.AddTestSweepers("db_sweeper", getDatabaseSweepers("db_sweeper"))
resource.AddTestSweepers("role_sweeper", getRolesSweeper("role_sweeper"))
resource.AddTestSweepers("user_sweeper", getUsersSweeper("user_sweeper"))
resource.AddTestSweepers("integration_sweeper", getIntegrationsSweeper("integration_sweeper"))
}
23 changes: 23 additions & 0 deletions pkg/snowflake/oauth_integration.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package snowflake

import (
"database/sql"
"log"

"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
Expand Down Expand Up @@ -37,3 +38,25 @@ func ScanOAuthIntegration(row *sqlx.Row) (*oauthIntegration, error) {
r := &oauthIntegration{}
return r, errors.Wrap(row.StructScan(r), "error scanning struct")
}

func ListIntegrations(db *sql.DB) ([]oauthIntegration, error) {
rows, err := db.Query("SHOW INTEGRATIONS")
if err != nil {
return nil, err
}

defer rows.Close()

r := []oauthIntegration{}
err = sqlx.StructScan(rows, &r)
if err == sql.ErrNoRows {
log.Println("[DEBUG] no integrations found")
return nil, nil
}
return r, nil
}

func DropIntegration(db *sql.DB, name string) error {
stmt := OAuthIntegration(name).Drop()
return Exec(db, stmt)
}

0 comments on commit d6397f9

Please sign in to comment.