diff --git a/.lgtm.yml b/.lgtm.yml new file mode 100644 index 0000000000..d8c0cfc74c --- /dev/null +++ b/.lgtm.yml @@ -0,0 +1,10 @@ + +extraction: + javascript: + index: + exclude: + - "*" + python: + index: + exclude: + - "*" diff --git a/README.md b/README.md index a4dad2bc08..3402a34091 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,7 @@ | [databricks_secret_scope](docs/resources/secret_scope.md) | [databricks_sql_endpoint](docs/resources/sql_endpoint.md) | [databricks_spark_version](docs/data-sources/spark_version.md) data +| [databricks_sql_permissions](docs/resources/sql_permissions.md) | [databricks_token](docs/resources/token.md) | [databricks_user](docs/resources/user.md) | [databricks_user_instance_profile](docs/resources/user_instance_profile.md) diff --git a/access/acceptance/sql_permissions_test.go b/access/acceptance/sql_permissions_test.go new file mode 100644 index 0000000000..0db7259a9a --- /dev/null +++ b/access/acceptance/sql_permissions_test.go @@ -0,0 +1,57 @@ +package acceptance + +import ( + "context" + "fmt" + "os" + "testing" + + "github.com/databrickslabs/terraform-provider-databricks/common" + "github.com/databrickslabs/terraform-provider-databricks/compute" + "github.com/databrickslabs/terraform-provider-databricks/internal/acceptance" + "github.com/databrickslabs/terraform-provider-databricks/qa" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAccTableACL(t *testing.T) { + cloudEnv := os.Getenv("CLOUD_ENV") + if cloudEnv == "" { + t.Skip("Acceptance tests skipped unless env 'CLOUD_ENV' is set") + } + + client := common.CommonEnvironmentClient() + client.WithCommandExecutor(func(ctx context.Context, + dc *common.DatabricksClient) common.CommandExecutor { + return compute.NewCommandsAPI(ctx, dc) + }) + + shell := client.CommandExecutor(context.Background()) + clusterInfo := compute.NewTinyClusterInCommonPoolPossiblyReused() + talbeName := qa.RandomName("table_acl_") + + cr := 
shell.Execute(clusterInfo.ClusterID, "python", + fmt.Sprintf("spark.range(10).write.saveAsTable('%s')", + talbeName)) + require.False(t, cr.Failed(), cr.Error()) + os.Setenv("TABLE_ACL_TEST_TABLE", talbeName) + defer func() { + cr := shell.Execute(clusterInfo.ClusterID, "sql", + fmt.Sprintf("DROP TABLE %s", talbeName)) + assert.False(t, cr.Failed(), cr.Error()) + }() + + acceptance.Test(t, []acceptance.Step{ + { + Template: ` + resource "databricks_sql_permissions" "this" { + table = "{env.TABLE_ACL_TEST_TABLE}" + + privilege_assignments { + principal = "users" + privileges = ["SELECT"] + } + }`, + }, + }) +} diff --git a/access/resource_sql_permissions.go b/access/resource_sql_permissions.go new file mode 100644 index 0000000000..8df13b0c22 --- /dev/null +++ b/access/resource_sql_permissions.go @@ -0,0 +1,352 @@ +package access + +import ( + "context" + "fmt" + "log" + "strings" + + "github.com/databrickslabs/terraform-provider-databricks/common" + "github.com/databrickslabs/terraform-provider-databricks/compute" + + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" +) + +// https://docs.databricks.com/security/access-control/table-acls/object-privileges.html#operations-and-privileges + +// SqlPermissions defines table access control +type SqlPermissions struct { + Table string `json:"table,omitempty"` + View string `json:"view,omitempty"` + Database string `json:"database,omitempty"` + Catalog bool `json:"catalog,omitempty"` + AnyFile bool `json:"any_file,omitempty"` + AnonymousFunction bool `json:"anonymous_function,omitempty"` + ClusterID string `json:"cluster_id,omitempty" tf:"computed"` + PrivilegeAssignments []PrivilegeAssignment `json:"privilege_assignments,omitempty"` + + exec common.CommandExecutor +} + +// PrivilegeAssignment ... 
+type PrivilegeAssignment struct { + Principal string `json:"principal"` + Privileges []string `json:"privileges" tf:"slice_set"` +} + +func (ta *SqlPermissions) actualDatabase() string { + if ta.Database == "" { + return "default" + } + return ta.Database +} + +// typeAndKey returns ACL object type and key +func (ta *SqlPermissions) typeAndKey() (string, string) { + if ta.Table != "" { + return "TABLE", fmt.Sprintf("`%s`.`%s`", ta.actualDatabase(), ta.Table) + } + if ta.View != "" { + return "VIEW", fmt.Sprintf("`%s`.`%s`", ta.actualDatabase(), ta.View) + } + if ta.Database != "" { + return "DATABASE", ta.Database + } + if ta.Catalog { + return "CATALOG", "" + } + if ta.AnyFile { + return "ANY FILE", "" + } + if ta.AnonymousFunction { + return "ANONYMOUS FUNCTION", "" + } + return "", "" +} + +// ID returns Terraform resource ID +func (ta *SqlPermissions) ID() string { + objectType, key := ta.typeAndKey() + if objectType == "" && key == "" { + return "" + } + noBackticks := strings.ReplaceAll(key, "`", "") + return fmt.Sprintf("%s/%s", strings.ToLower(objectType), noBackticks) +} + +func loadTableACL(id string) (SqlPermissions, error) { + ta := SqlPermissions{} + split := strings.SplitN(id, "/", 2) + if len(split) != 2 { + return ta, fmt.Errorf("ID must be two elements: %s", id) + } + switch strings.ToLower(split[0]) { + case "database": + ta.Database = split[1] + case "view": + dav := strings.SplitN(split[1], ".", 2) + if len(dav) != 2 { + return ta, fmt.Errorf("view must have two elements") + } + ta.Database = dav[0] + ta.View = dav[1] + case "table": + dav := strings.SplitN(split[1], ".", 2) + if len(dav) != 2 { + return ta, fmt.Errorf("table must have two elements") + } + ta.Database = dav[0] + ta.Table = dav[1] + case "catalog": + ta.Catalog = true + case "any file": + ta.AnyFile = true + case "anonymous function": + ta.AnonymousFunction = true + default: + return ta, fmt.Errorf("illegal ID type: %s", split[0]) + } + return ta, nil +} + +func (ta 
*SqlPermissions) read() error { + thisType, thisKey := ta.typeAndKey() + if thisType == "" && thisKey == "" { + return fmt.Errorf("invalid ID") + } + currentGrantsOnThis := ta.exec.Execute(ta.ClusterID, "sql", fmt.Sprintf( + "SHOW GRANT ON %s %s", thisType, thisKey)) + if currentGrantsOnThis.Failed() { + failure := currentGrantsOnThis.Error() + if strings.Contains(failure, "does not exist") || + strings.Contains(failure, "RESOURCE_DOES_NOT_EXIST") { + return common.NotFound(failure) + } + return fmt.Errorf(failure) + } + // clear any previous entries + ta.PrivilegeAssignments = []PrivilegeAssignment{} + + // iterate over existing permissions over given data object + var currentPrincipal, currentAction, currentType, currentKey string + for currentGrantsOnThis.Scan(¤tPrincipal, ¤tAction, ¤tType, ¤tKey) { + if !strings.EqualFold(currentType, thisType) { + continue + } + if !strings.EqualFold(currentKey, thisKey) { + continue + } + if strings.HasPrefix(currentAction, "DENIED_") { + // DENY statements are intentionally not supported. 
+ continue + } + if currentAction == "OWN" { + // skip table ownership definitions for now + continue + } + // find existing grants for all principals + var privileges *[]string + for i, privilegeAssignment := range ta.PrivilegeAssignments { + // correct all privileges for the same principal into a slide + if privilegeAssignment.Principal == currentPrincipal { + privileges = &ta.PrivilegeAssignments[i].Privileges + } + } + if privileges == nil { + // initialize permissions wrapper for a principal not seen + // in previous iterations + firstSeenPrincipalPermissions := PrivilegeAssignment{ + Principal: currentPrincipal, + Privileges: []string{}, + } + // point privileges to be of the newly added principal + ta.PrivilegeAssignments = append(ta.PrivilegeAssignments, firstSeenPrincipalPermissions) + privileges = &ta.PrivilegeAssignments[len(ta.PrivilegeAssignments)-1].Privileges + } + // add action for the principal on current iteration + *privileges = append(*privileges, currentAction) + } + return nil +} + +func (ta *SqlPermissions) revoke() error { + existing, err := loadTableACL(ta.ID()) + if err != nil { + return err + } + existing.exec = ta.exec + existing.ClusterID = ta.ClusterID + if err = existing.read(); err != nil { + return err + } + for _, privilegeAssignment := range existing.PrivilegeAssignments { + if err = ta.apply(func(objType, key string) string { + return fmt.Sprintf("REVOKE ALL PRIVILEGES ON %s %s FROM `%s`", + objType, key, privilegeAssignment.Principal) + }); err != nil { + return err + } + } + return nil +} + +func (ta *SqlPermissions) enforce() (err error) { + if err = ta.revoke(); err != nil { + return err + } + for _, privilegeAssignment := range ta.PrivilegeAssignments { + if err = ta.apply(func(objType, key string) string { + privileges := strings.Join(privilegeAssignment.Privileges, ", ") + return fmt.Sprintf("GRANT %s ON %s %s TO `%s`", + privileges, objType, key, privilegeAssignment.Principal) + }); err != nil { + return err + } + } + 
return nil +} + +func (ta *SqlPermissions) apply(qb func(objType, key string) string) error { + objType, key := ta.typeAndKey() + if objType == "" && key == "" { + return fmt.Errorf("invalid ID") + } + sqlQuery := qb(objType, key) + log.Printf("[INFO] Executing SQL: %s", sqlQuery) + r := ta.exec.Execute(ta.ClusterID, "sql", sqlQuery) + return r.Err() +} + +func (ta *SqlPermissions) initCluster(ctx context.Context, d *schema.ResourceData, c *common.DatabricksClient) (err error) { + clustersAPI := compute.NewClustersAPI(ctx, c) + if ci, ok := d.GetOk("cluster_id"); ok { + ta.ClusterID = ci.(string) + } else { + ta.ClusterID, err = ta.getOrCreateCluster(clustersAPI) + if err != nil { + return + } + } + clusterInfo, err := clustersAPI.StartAndGetInfo(ta.ClusterID) + if e, ok := err.(common.APIError); ok && e.IsMissing() { + // cluster that was previously in a tfstate was deleted + ta.ClusterID, err = ta.getOrCreateCluster(clustersAPI) + if err != nil { + return + } + clusterInfo, err = clustersAPI.StartAndGetInfo(ta.ClusterID) + } + if err != nil { + return + } + if v, ok := clusterInfo.SparkConf["spark.databricks.acl.dfAclsEnabled"]; !ok || v != "true" { + err = fmt.Errorf("cluster_id: not a High-Concurrency cluster: %s (%s)", + clusterInfo.ClusterName, clusterInfo.ClusterID) + return + } + ta.exec = c.CommandExecutor(ctx) + return nil +} + +func (ta *SqlPermissions) getOrCreateCluster(clustersAPI compute.ClustersAPI) (string, error) { + sparkVersion := clustersAPI.LatestSparkVersionOrDefault(compute.SparkVersionRequest{ + Latest: true, + }) + nodeType := clustersAPI.GetSmallestNodeType(compute.NodeTypeRequest{LocalDisk: true}) + aclCluster, err := clustersAPI.GetOrCreateRunningCluster( + "terrraform-table-acl", compute.Cluster{ + ClusterName: "terrraform-table-acl", + SparkVersion: sparkVersion, + NodeTypeID: nodeType, + AutoterminationMinutes: 10, + SparkConf: map[string]string{ + "spark.databricks.acl.dfAclsEnabled": "true", + 
"spark.databricks.repl.allowedLanguages": "python,sql", + "spark.databricks.cluster.profile": "serverless", + "spark.master": "local[*]", + }, + CustomTags: map[string]string{ + "ResourceClass": "SingleNode", + }, + }) + if err != nil { + return "", err + } + return aclCluster.ClusterID, nil +} + +func tableAclForUpdate(ctx context.Context, d *schema.ResourceData, + s map[string]*schema.Schema, c *common.DatabricksClient) (ta SqlPermissions, err error) { + if err = common.DataToStructPointer(d, s, &ta); err != nil { + return + } + err = ta.initCluster(ctx, d, c) + return +} + +func tableAclForLoad(ctx context.Context, d *schema.ResourceData, + s map[string]*schema.Schema, c *common.DatabricksClient) (ta SqlPermissions, err error) { + ta, err = loadTableACL(d.Id()) + if err != nil { + return + } + err = ta.initCluster(ctx, d, c) + return +} + +// ResourceSqlPermissions manages table ACLs +func ResourceSqlPermissions() *schema.Resource { + s := common.StructToSchema(SqlPermissions{}, func(s map[string]*schema.Schema) map[string]*schema.Schema { + alof := []string{"database", "table", "view", "catalog", "any_file", "anonymous_function"} + for _, field := range alof { + s[field].ForceNew = true + s[field].Optional = true + s[field].AtLeastOneOf = alof + } + s["cluster_id"].Computed = true + s["database"].Default = "default" + return s + }) + return common.Resource{ + Schema: s, + Create: func(ctx context.Context, d *schema.ResourceData, c *common.DatabricksClient) error { + ta, err := tableAclForUpdate(ctx, d, s, c) + if err != nil { + return err + } + if err = ta.enforce(); err != nil { + return err + } + d.SetId(ta.ID()) + return nil + }, + Read: func(ctx context.Context, d *schema.ResourceData, c *common.DatabricksClient) error { + ta, err := tableAclForLoad(ctx, d, s, c) + if err != nil { + return err + } + if err = ta.read(); err != nil { + return err + } + if len(ta.PrivilegeAssignments) == 0 { + // reflect resource is skipping empty privilege_assignments + 
d.Set("privilege_assignments", []interface{}{}) + } + return common.StructToData(ta, s, d) + }, + Update: func(ctx context.Context, d *schema.ResourceData, c *common.DatabricksClient) error { + ta, err := tableAclForUpdate(ctx, d, s, c) + if err != nil { + return err + } + return ta.enforce() + }, + Delete: func(ctx context.Context, d *schema.ResourceData, c *common.DatabricksClient) error { + ta, err := tableAclForLoad(ctx, d, s, c) + if err != nil { + return err + } + return ta.revoke() + }, + }.ToResource() +} diff --git a/access/resource_sql_permissions_test.go b/access/resource_sql_permissions_test.go new file mode 100644 index 0000000000..381dcffb69 --- /dev/null +++ b/access/resource_sql_permissions_test.go @@ -0,0 +1,408 @@ +package access + +import ( + "fmt" + "testing" + + "github.com/databrickslabs/terraform-provider-databricks/common" + "github.com/databrickslabs/terraform-provider-databricks/compute" + "github.com/databrickslabs/terraform-provider-databricks/qa" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTableACLID(t *testing.T) { + for id, ta := range map[string]SqlPermissions{ + "table/default.foo": {Table: "foo"}, + "view/bar.foo": {View: "foo", Database: "bar"}, + "database/bar": {Database: "bar"}, + "catalog/": {Catalog: true}, + "any file/": {AnyFile: true}, + "anonymous function/": {AnonymousFunction: true}, + } { + assert.Equal(t, id, ta.ID()) + ta2, err := loadTableACL(id) + if ta.Database == "" && ta2.Database == "default" { + ta.Database = "default" + } + assert.NoError(t, err, fmt.Sprintf("%v", err)) + assert.Equal(t, ta, ta2, id) + } +} + +func TestTableACLID_errors(t *testing.T) { + for id, exp := range map[string]string{ + "table": "ID must be two elements: table", + "table/beep": "table must have two elements", + "view/beep": "view must have two elements", + "vuew/beep": "illegal ID type: vuew", + } { + _, err := loadTableACL(id) + assert.EqualError(t, err, exp) + } +} + +type mockData 
map[string][][]string + +func (md mockData) Execute(clusterID, language, commandStr string) common.CommandResults { + data, ok := md[commandStr] + if !ok { + return common.CommandResults{ + ResultType: "error", + Summary: fmt.Sprintf("Query is not mocked: %s", commandStr), + } + } + var x []interface{} + for _, a := range data { + var y []interface{} + for _, b := range a { + y = append(y, b) + } + x = append(x, y) + } + return common.CommandResults{ + ResultType: "table", + Data: x, + } +} + +func (md mockData) toCommandMock() func(string) common.CommandResults { + return func(commandStr string) common.CommandResults { + return md.Execute("_", "cobol", commandStr) + } +} + +func TestTableACLGrants(t *testing.T) { + ta := SqlPermissions{Table: "foo", exec: mockData{ + "SHOW GRANT ON TABLE `default`.`foo`": { + // principal, actionType, objType, objectKey + {"users", "SELECT", "database", "foo"}, + {"users", "SELECT", "table", "`default`.`foo`"}, + {"users", "READ", "table", "`default`.`foo`"}, + {"users", "SELECT", "database", "default"}, + {"interns", "DENIED_SELECT", "table", "`default`.`foo`"}, + }, + }} + err := ta.read() + assert.NoError(t, err) + assert.Len(t, ta.PrivilegeAssignments, 1) + assert.Len(t, ta.PrivilegeAssignments[0].Privileges, 2) +} + +type failedCommand string + +func (fc failedCommand) Execute(clusterID, language, commandStr string) common.CommandResults { + return common.CommandResults{ + ResultType: "error", + Summary: string(fc), + } +} + +func (fc failedCommand) toCommandMock() func(commandStr string) common.CommandResults { + return func(commandStr string) common.CommandResults { + return fc.Execute("..", "sql", commandStr) + } +} + +func TestTableACL_NotFound(t *testing.T) { + ta := SqlPermissions{Table: "foo", exec: failedCommand("Table does not exist")} + err := ta.read() + assert.EqualError(t, err, "Table does not exist") +} + +func TestTableACL_OtherError(t *testing.T) { + ta := SqlPermissions{Table: "foo", exec: failedCommand("Some 
error")} + err := ta.read() + assert.EqualError(t, err, "Some error") +} + +func TestTableACL_Revoke(t *testing.T) { + ta := SqlPermissions{Table: "foo", exec: mockData{ + "SHOW GRANT ON TABLE `default`.`foo`": { + {"users", "SELECT", "database", "foo"}, + {"users", "SELECT", "table", "`default`.`foo`"}, + {"users", "READ", "table", "`default`.`foo`"}, + {"users", "SELECT", "database", "default"}, + }, + "REVOKE ALL PRIVILEGES ON TABLE `default`.`foo` FROM `users`": {}, + "REVOKE ALL PRIVILEGES ON TABLE `default`.`foo` FROM `interns`": {}, + }} + err := ta.revoke() + require.NoError(t, err) +} + +func TestTableACL_Enforce(t *testing.T) { + ta := SqlPermissions{ + Table: "foo", + PrivilegeAssignments: []PrivilegeAssignment{ + {"engineers", []string{"MODIFY", "SELECT", "READ"}}, + {"support", []string{"SELECT"}}, + }, + exec: mockData{ + "SHOW GRANT ON TABLE `default`.`foo`": { + {"users", "SELECT", "database", "foo"}, + {"users", "SELECT", "table", "`default`.`foo`"}, + {"users", "READ", "table", "`default`.`foo`"}, + {"users", "SELECT", "database", "default"}, + {"interns", "DENIED_SELECT", "table", "`default`.`foo`"}, + {"interns", "DENIED_READ", "table", "`default`.`foo`"}, + }, + "REVOKE ALL PRIVILEGES ON TABLE `default`.`foo` FROM `users`": {}, + "REVOKE ALL PRIVILEGES ON TABLE `default`.`foo` FROM `interns`": {}, + "GRANT MODIFY, SELECT, READ ON TABLE `default`.`foo` TO `engineers`": {}, + "GRANT SELECT ON TABLE `default`.`foo` TO `support`": {}, + }, + } + err := ta.enforce() + require.NoError(t, err) +} + +var createHighConcurrencyCluster = []qa.HTTPFixture{ + { + Method: "GET", + ReuseRequest: true, + Resource: "/api/2.0/clusters/list", + Response: map[string]interface{}{}, + }, + { + Method: "GET", + ReuseRequest: true, + Resource: "/api/2.0/clusters/spark-versions", + Response: compute.SparkVersionsList{ + SparkVersions: []compute.SparkVersion{ + { + Version: "7.1.x-cpu-ml-scala2.12", + Description: "7.1 ML (includes Apache Spark 3.0.0, Scala 2.12)", + }, 
+ }, + }, + }, + { + Method: "GET", + ReuseRequest: true, + Resource: "/api/2.0/clusters/list-node-types", + Response: compute.NodeTypeList{ + NodeTypes: []compute.NodeType{ + { + NodeTypeID: "Standard_F4s", + InstanceTypeID: "Standard_F4s", + MemoryMB: 8192, + NumCores: 4, + NodeInstanceType: &compute.NodeInstanceType{ + LocalDisks: 1, + InstanceTypeID: "Standard_F4s", + LocalDiskSizeGB: 16, + LocalNVMeDisks: 0, + }, + }, + }, + }, + }, + { + Method: "POST", + ReuseRequest: true, + Resource: "/api/2.0/clusters/create", + ExpectedRequest: compute.Cluster{ + AutoterminationMinutes: 10, + ClusterName: "terrraform-table-acl", + NodeTypeID: "Standard_F4s", + SparkVersion: "7.3.x-scala2.12", + CustomTags: map[string]string{ + "ResourceClass": "SingleNode", + }, + SparkConf: map[string]string{ + "spark.databricks.acl.dfAclsEnabled": "true", + "spark.databricks.repl.allowedLanguages": "python,sql", + "spark.databricks.cluster.profile": "serverless", + "spark.master": "local[*]", + }, + }, + Response: compute.ClusterID{ + ClusterID: "bcd", + }, + }, + { + Method: "GET", + ReuseRequest: true, + Resource: "/api/2.0/clusters/get?cluster_id=bcd", + Response: compute.ClusterInfo{ + ClusterID: "bcd", + State: "RUNNING", + SparkConf: map[string]string{ + "spark.databricks.acl.dfAclsEnabled": "true", + "spark.databricks.cluster.profile": "singleNode", + }, + }, + }, +} + +func TestResourceSqlPermissions_Read(t *testing.T) { + qa.ResourceFixture{ + CommandMock: mockData{ + "SHOW GRANT ON TABLE `default`.`foo`": { + {"users", "SELECT", "database", "foo"}, + {"users", "SELECT", "table", "`default`.`foo`"}, + {"bob@example.com", "OWN", "table", "`default`.`foo`"}, + {"users", "READ", "table", "`default`.`foo`"}, + {"users", "SELECT", "database", "default"}, + {"interns", "DENIED_SELECT", "table", "`default`.`foo`"}, + }, + }.toCommandMock(), + Fixtures: createHighConcurrencyCluster, + Resource: ResourceSqlPermissions(), + Read: true, + New: true, + ID: "table/default.foo", + 
}.ApplyNoError(t) +} + +func TestResourceSqlPermissions_Read_Error(t *testing.T) { + qa.ResourceFixture{ + Resource: ResourceSqlPermissions(), + Read: true, + New: true, + ID: "something", + }.ExpectError(t, "ID must be two elements: something") +} + +func TestResourceSqlPermissions_Read_ErrorCommand(t *testing.T) { + qa.ResourceFixture{ + CommandMock: failedCommand("does not compute").toCommandMock(), + Fixtures: createHighConcurrencyCluster, + Resource: ResourceSqlPermissions(), + ID: "database/foo", + Read: true, + New: true, + }.ExpectError(t, "does not compute") +} + +func TestResourceSqlPermissions_Create(t *testing.T) { + qa.ResourceFixture{ + CommandMock: mockData{ + "SHOW GRANT ON TABLE `default`.`foo`": { + // TODO: transform mockData into a sequence, + // as this query should return two different results, + // based on the order of execution + {"users", "SELECT", "database", "foo"}, + {"users", "SELECT", "table", "`default`.`foo`"}, + {"users", "SELECT", "database", "default"}, + {"interns", "DENIED_SELECT", "table", "`default`.`foo`"}, + }, + "REVOKE ALL PRIVILEGES ON TABLE `default`.`foo` FROM `users`": {}, + "REVOKE ALL PRIVILEGES ON TABLE `default`.`foo` FROM `interns`": {}, + "GRANT MODIFY, SELECT ON TABLE `default`.`foo` TO `serge@example.com`": {}, + }.toCommandMock(), + HCL: ` + table = "foo" + privilege_assignments { + principal = "serge@example.com" + privileges = ["SELECT", "MODIFY"] + } + `, + Fixtures: createHighConcurrencyCluster, + Resource: ResourceSqlPermissions(), + Create: true, + }.ApplyNoError(t) +} + +func TestResourceSqlPermissions_Create_Error(t *testing.T) { + qa.ResourceFixture{ + HCL: `table = "foo" + privilege_assignments { + principal = "serge@example.com" + privileges = ["SELECT", "READ", "MODIFY"] + }`, + CommandMock: failedCommand("Some error").toCommandMock(), + Fixtures: createHighConcurrencyCluster, + Resource: ResourceSqlPermissions(), + Create: true, + }.ExpectError(t, "Some error") +} + +func 
TestResourceSqlPermissions_Create_Error2(t *testing.T) { + qa.ResourceFixture{ + HCL: `table = "foo" + privilege_assignments { + principal = "serge@example.com" + privileges = ["SELECT", "READ", "MODIFY"] + }`, + CommandMock: func(commandStr string) common.CommandResults { + md := mockData{ + "SHOW GRANT ON TABLE `default`.`foo`": {}, + } + if _, ok := md[commandStr]; ok { + return md.toCommandMock()(commandStr) + } + return common.CommandResults{ + ResultType: "error", + Cause: "com.x.y.z.d.Exceptions$SQLExecutionException: org.apache.spark.s...", + Summary: "Error in SQL statement: ParseException: \nAction Unknown ActionType READ cannot be granted on tab... (127 more bytes)", + } + }, + Fixtures: createHighConcurrencyCluster, + Resource: ResourceSqlPermissions(), + Create: true, + }.ExpectError(t, "Action Unknown ActionType READ cannot be granted on tab... (127 more bytes)") +} + +func TestResourceSqlPermissions_Update(t *testing.T) { + qa.ResourceFixture{ + CommandMock: mockData{ + "SHOW GRANT ON TABLE `default`.`foo`": { + // TODO: transform mockData into a sequence, + // as this query should return two different results, + // based on the order of execution + {"users", "SELECT", "database", "foo"}, + {"users", "SELECT", "table", "`default`.`foo`"}, + {"users", "READ", "table", "`default`.`foo`"}, + {"users", "SELECT", "database", "default"}, + {"interns", "DENIED_SELECT", "table", "`default`.`foo`"}, + }, + "REVOKE ALL PRIVILEGES ON TABLE `default`.`foo` FROM `users`": {}, + "REVOKE ALL PRIVILEGES ON TABLE `default`.`foo` FROM `interns`": {}, + "GRANT READ, MODIFY, SELECT ON TABLE `default`.`foo` TO `serge@example.com`": {}, + }.toCommandMock(), + HCL: ` + table = "foo" + privilege_assignments { + principal = "serge@example.com" + privileges = ["SELECT", "READ", "MODIFY"] + } + `, + Fixtures: createHighConcurrencyCluster, + Resource: ResourceSqlPermissions(), + Update: true, + ID: "table/default.foo", + }.ApplyNoError(t) +} + +func 
TestResourceSqlPermissions_Delete(t *testing.T) { + qa.ResourceFixture{ + CommandMock: mockData{ + "SHOW GRANT ON TABLE `default`.`foo`": { + {"users", "SELECT", "database", "foo"}, + {"users", "SELECT", "table", "`default`.`foo`"}, + {"users", "READ", "table", "`default`.`foo`"}, + {"users", "SELECT", "database", "default"}, + {"interns", "DENIED_SELECT", "table", "`default`.`foo`"}, + }, + "REVOKE ALL PRIVILEGES ON TABLE `default`.`foo` FROM `users`": {}, + "REVOKE ALL PRIVILEGES ON TABLE `default`.`foo` FROM `interns`": {}, + }.toCommandMock(), + HCL: ` + table = "foo" + privilege_assignments { + principal = "serge@example.com" + privileges = ["SELECT", "READ", "MODIFY"] + } + `, + Fixtures: createHighConcurrencyCluster, + Resource: ResourceSqlPermissions(), + Delete: true, + ID: "table/default.foo", + }.ApplyNoError(t) +} + +func TestResourceSqlPermissions_CornerCases(t *testing.T) { + qa.ResourceCornerCases(t, ResourceSqlPermissions(), "database/foo") +} diff --git a/common/azure_auth.go b/common/azure_auth.go index 1957a66a41..f25f0cd1f7 100644 --- a/common/azure_auth.go +++ b/common/azure_auth.go @@ -45,23 +45,21 @@ type AzureAuth struct { azureManagementEndpoint string authorizer autorest.Authorizer - temporaryPat *TokenResponse + temporaryPat *tokenResponse } -// TokenRequest contains request -type TokenRequest struct { +type tokenRequest struct { LifetimeSeconds int64 `json:"lifetime_seconds,omitempty"` Comment string `json:"comment,omitempty"` } -// TokenResponse contains response -type TokenResponse struct { +type tokenResponse struct { TokenValue string `json:"token_value,omitempty"` - TokenInfo *TokenInfo `json:"token_info,omitempty"` + TokenInfo *tokenInfo `json:"token_info,omitempty"` } -// TokenInfo is a struct that contains metadata about a given token -type TokenInfo struct { +// tokenInfo is a struct that contains metadata about a given token +type tokenInfo struct { TokenID string `json:"token_id,omitempty"` CreationTime int64 
`json:"creation_time,omitempty"` ExpiryTime int64 `json:"expiry_time,omitempty"` @@ -211,7 +209,7 @@ func (aa *AzureAuth) simpleAADRequestVisitor( func (aa *AzureAuth) acquirePAT( ctx context.Context, factory func(resource string) (autorest.Authorizer, error), - visitors ...func(r *http.Request, ma autorest.Authorizer) error) (*TokenResponse, error) { + visitors ...func(r *http.Request, ma autorest.Authorizer) error) (*tokenResponse, error) { if aa.temporaryPat != nil { // todo: add IsExpired return aa.temporaryPat, nil @@ -261,12 +259,12 @@ func (aa *AzureAuth) acquirePAT( return aa.temporaryPat, nil } -func (aa *AzureAuth) patRequest() TokenRequest { +func (aa *AzureAuth) patRequest() tokenRequest { seconds, err := strconv.ParseInt(aa.PATTokenDurationSeconds, 10, 64) if err != nil { seconds = 60 * 60 } - return TokenRequest{ + return tokenRequest{ LifetimeSeconds: seconds, Comment: "Secret made via Terraform", } @@ -315,7 +313,7 @@ func (aa *AzureAuth) ensureWorkspaceURL(ctx context.Context, } func (aa *AzureAuth) createPAT(ctx context.Context, - interceptor func(r *http.Request) error) (tr TokenResponse, err error) { + interceptor func(r *http.Request) error) (tr tokenResponse, err error) { log.Println("[DEBUG] Creating workspace token") url := fmt.Sprintf("%sapi/2.0/token/create", aa.databricksClient.Host) body, err := aa.databricksClient.genericQuery(ctx, diff --git a/common/azure_auth_test.go b/common/azure_auth_test.go index 32d7943239..53db49136c 100644 --- a/common/azure_auth_test.go +++ b/common/azure_auth_test.go @@ -128,7 +128,7 @@ func TestAcquirePAT_CornerCases(t *testing.T) { assert.EqualError(t, err, "DatabricksClient is not configured") aa.databricksClient = &DatabricksClient{} - aa.temporaryPat = &TokenResponse{ + aa.temporaryPat = &tokenResponse{ TokenValue: "...", } auth, rre := aa.acquirePAT(context.Background(), func(resource string) (autorest.Authorizer, error) { diff --git a/common/commands.go b/common/commands.go index 
12f4b475fd..af3dfb620b 100644 --- a/common/commands.go +++ b/common/commands.go @@ -1,6 +1,25 @@ package common -import "context" +import ( + "context" + "fmt" + "html" + "regexp" + "strings" +) + +var ( + // IPython's output prefixes + outRE = regexp.MustCompile(`Out\[[\d\s]+\]:\s`) + // HTML tags + tagRE = regexp.MustCompile(`<[^>]*>`) + // just exception content without exception name + exceptionRE = regexp.MustCompile(`.*Exception:\s+(.*)`) + // execution errors resulting from http errors are sometimes hidden in these keys + executionErrorRE = regexp.MustCompile(`ExecutionError: ([\s\S]*)\n(StatusCode=[0-9]*)\n(StatusDescription=.*)\n`) + // usual error message explanation is hidden in this key + errorMessageRE = regexp.MustCompile(`ErrorMessage=(.+)\n`) +) // WithCommandMock mocks all command executions for this client func (c *DatabricksClient) WithCommandMock(mock CommandMock) { @@ -22,7 +41,7 @@ func (c *DatabricksClient) CommandExecutor(ctx context.Context) CommandExecutor } // CommandMock mocks the execution of command -type CommandMock func(commandStr string) (string, error) +type CommandMock func(commandStr string) CommandResults // CommandExecutorMock simplifies command testing type commandExecutorMock struct { @@ -30,11 +49,99 @@ type commandExecutorMock struct { } // Execute mock command with given mock function -func (c commandExecutorMock) Execute(clusterID, language, commandStr string) (string, error) { +func (c commandExecutorMock) Execute(clusterID, language, commandStr string) CommandResults { return c.mock(commandStr) } // CommandExecutor creates a spark context and executes a command and then closes context type CommandExecutor interface { - Execute(clusterID, language, commandStr string) (string, error) + Execute(clusterID, language, commandStr string) CommandResults +} + +// CommandResults captures results of a command +type CommandResults struct { + ResultType string `json:"resultType,omitempty"` + Summary string `json:"summary,omitempty"` 
+ Cause string `json:"cause,omitempty"` + Data interface{} `json:"data,omitempty"` + Schema interface{} `json:"schema,omitempty"` + Truncated bool `json:"truncated,omitempty"` + IsJSONSchema bool `json:"isJsonSchema,omitempty"` + pos int +} + +// Failed tells if command execution failed +func (cr *CommandResults) Failed() bool { + return cr.ResultType == "error" +} + +// Text returns plain text results +func (cr *CommandResults) Text() string { + if cr.ResultType != "text" { + return "" + } + return outRE.ReplaceAllLiteralString(cr.Data.(string), "") +} + +// Err returns error type +func (cr *CommandResults) Err() error { + if !cr.Failed() { + return nil + } + return fmt.Errorf(cr.Error()) +} + +// Error returns error in a bit more friendly way +func (cr *CommandResults) Error() string { + if cr.ResultType != "error" { + return "" + } + summary := tagRE.ReplaceAllLiteralString(cr.Summary, "") + summary = html.UnescapeString(summary) + + exceptionMatches := exceptionRE.FindStringSubmatch(summary) + if len(exceptionMatches) == 2 { + summary = strings.ReplaceAll(exceptionMatches[1], "; nested exception is:", "") + summary = strings.TrimRight(summary, " ") + return summary + } + + executionErrorMatches := executionErrorRE.FindStringSubmatch(cr.Cause) + if len(executionErrorMatches) == 4 { + return strings.Join(executionErrorMatches[1:], "\n") + } + + errorMessageMatches := errorMessageRE.FindStringSubmatch(cr.Cause) + if len(errorMessageMatches) == 2 { + return errorMessageMatches[1] + } + + return summary +} + +// Scan scans for results +func (cr *CommandResults) Scan(dest ...interface{}) bool { + if cr.ResultType != "table" { + return false + } + if rows, ok := cr.Data.([]interface{}); ok { + if cr.pos >= len(rows) { + return false + } + if cols, ok := rows[cr.pos].([]interface{}); ok { + for i := range dest { + switch d := dest[i].(type) { + case *string: + *d = cols[i].(string) + case *int: + *d = cols[i].(int) + case *bool: + *d = cols[i].(bool) + } + } + cr.pos++ 
+ return true + } + } + return false } diff --git a/common/commands_test.go b/common/commands_test.go index 7c38eb3dac..96c23ea220 100644 --- a/common/commands_test.go +++ b/common/commands_test.go @@ -16,15 +16,61 @@ func TestCommandMock(t *testing.T) { assert.NoError(t, err) called := false - c.WithCommandMock(func(commandStr string) (string, error) { + c.WithCommandMock(func(commandStr string) CommandResults { called = true assert.Equal(t, "print 1", commandStr) - return "done", nil + return CommandResults{ + ResultType: "text", + Data: "done", + } }) ctx := context.Background() - res, err := c.CommandExecutor(ctx).Execute("irrelevant", "python", "print 1") + cr := c.CommandExecutor(ctx).Execute("irrelevant", "python", "print 1") assert.Equal(t, true, called) - assert.Equal(t, "done", res) - assert.NoError(t, err, err) + assert.Equal(t, false, cr.Failed()) + assert.Equal(t, "done", cr.Text()) +} + +func TestCommandResults_Error(t *testing.T) { + cr := CommandResults{} + assert.NoError(t, cr.Err()) + cr.ResultType = "error" + assert.EqualError(t, cr.Err(), "") + + cr.Summary = "NotFoundException: Things are going wrong; nested exception is: with something" + assert.Equal(t, "Things are going wrong with something", cr.Error()) + + cr.Summary = "" + cr.Cause = "ExecutionError: \nStatusCode=400\nStatusDescription=ABC\nSomething else" + assert.Equal(t, "\nStatusCode=400\nStatusDescription=ABC", cr.Error()) + + cr.Cause = "ErrorMessage=Error was here\n" + assert.Equal(t, "Error was here", cr.Error()) + + assert.False(t, cr.Scan()) +} + +func TestCommandResults_Scan(t *testing.T) { + cr := CommandResults{ + ResultType: "table", + Data: []interface{}{ + []interface{}{"foo", 1, true}, + []interface{}{"bar", 2, false}, + }, + } + a := "" + b := 0 + c := false + assert.True(t, cr.Scan(&a, &b, &c)) + assert.Equal(t, "foo", a) + assert.Equal(t, 1, b) + assert.Equal(t, true, c) + + assert.True(t, cr.Scan(&a, &b, &c)) + assert.Equal(t, "bar", a) + assert.Equal(t, 2, b) + 
assert.Equal(t, false, c) + + assert.False(t, cr.Scan(&a, &b, &c)) } diff --git a/common/http.go b/common/http.go index b19cd70e15..e6d35eddea 100644 --- a/common/http.go +++ b/common/http.go @@ -443,7 +443,7 @@ func (c *DatabricksClient) genericQuery(ctx context.Context, method, requestURL headers += "\n" } } - log.Printf("[DEBUG] %s %s %s%v", method, requestURL, headers, c.redactedDump(requestBody)) + log.Printf("[DEBUG] %s %s %s%v", method, requestURL, headers, c.redactedDump(requestBody)) // lgtm[go/clear-text-logging] r, err := retryablehttp.FromRequest(request) if err != nil { diff --git a/common/reflect_resource.go b/common/reflect_resource.go index 1d1140c444..68d3d9fa20 100644 --- a/common/reflect_resource.go +++ b/common/reflect_resource.go @@ -106,6 +106,11 @@ func chooseFieldName(typeField reflect.StructField) string { return alias } jsonTag := typeField.Tag.Get("json") + // fields without JSON tags would be treated as if ignored, + // but keeping linters happy + if jsonTag == "" { + return "-" + } return strings.Split(jsonTag, ",")[0] } diff --git a/compute/commands.go b/compute/commands.go index 273e9798f6..33d46ea6e9 100644 --- a/compute/commands.go +++ b/compute/commands.go @@ -2,12 +2,8 @@ package compute import ( "context" - "errors" "fmt" - "html" "log" - "regexp" - "strings" "time" "github.com/databrickslabs/terraform-provider-databricks/common" @@ -16,19 +12,6 @@ import ( "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" ) -var ( - // IPython's output prefixes - outRE = regexp.MustCompile(`Out\[[\d\s]+\]:\s`) - // HTML tags - tagRE = regexp.MustCompile(`<[^>]*>`) - // just exception content without exception name - exceptionRE = regexp.MustCompile(`.*Exception: (.*)`) - // execution errors resulting from http errors are sometimes hidden in these keys - executionErrorRE = regexp.MustCompile(`ExecutionError: ([\s\S]*)\n(StatusCode=[0-9]*)\n(StatusDescription=.*)\n`) - // usual error message explanation is hidden in this key - 
errorMessageRE = regexp.MustCompile(`ErrorMessage=(.+)\n`) -) - // NewCommandsAPI creates CommandsAPI instance from provider meta func NewCommandsAPI(ctx context.Context, m interface{}) CommandsAPI { return CommandsAPI{ @@ -45,84 +28,73 @@ type CommandsAPI struct { // Execute creates a spark context and executes a command and then closes context // Any leading whitespace is trimmed -func (a CommandsAPI) Execute(clusterID, language, commandStr string) (result string, err error) { +func (a CommandsAPI) Execute(clusterID, language, commandStr string) common.CommandResults { cluster, err := NewClustersAPI(a.context, a.client).Get(clusterID) if err != nil { - return + return common.CommandResults{ + ResultType: "error", + Summary: err.Error(), + } } if !cluster.IsRunningOrResizing() { - err = fmt.Errorf("Cluster %s has to be running or resizing, but is %s", clusterID, cluster.State) - return + return common.CommandResults{ + ResultType: "error", + Summary: fmt.Sprintf("Cluster %s has to be running or resizing, but is %s", clusterID, cluster.State), + } } commandStr = internal.TrimLeadingWhitespace(commandStr) log.Printf("[INFO] Executing %s command on %s:\n%s", language, clusterID, commandStr) context, err := a.createContext(language, clusterID) if err != nil { - return + return common.CommandResults{ + ResultType: "error", + Summary: err.Error(), + } } err = a.waitForContextReady(context, clusterID) if err != nil { - return + return common.CommandResults{ + ResultType: "error", + Summary: err.Error(), + } } commandID, err := a.createCommand(context, clusterID, language, commandStr) if err != nil { - return + return common.CommandResults{ + ResultType: "error", + Summary: err.Error(), + } } + // TODO: merge getCommand and waitForCommandFinished to "waitForCommandResults" err = a.waitForCommandFinished(commandID, context, clusterID) if err != nil { - return + return common.CommandResults{ + ResultType: "error", + Summary: err.Error(), + } } command, err := 
a.getCommand(commandID, context, clusterID) if err != nil { - return + return common.CommandResults{ + ResultType: "error", + Summary: err.Error(), + } } err = a.deleteContext(context, clusterID) if err != nil { - return + return common.CommandResults{ + ResultType: "error", + Summary: err.Error(), + } } - return a.parseCommandResults(command) -} - -func (a CommandsAPI) parseCommandResults(command Command) (result string, err error) { if command.Results == nil { - err = fmt.Errorf("Command has no results: %#v", command) - return - } - switch command.Results.ResultType { - case "text": - result = outRE.ReplaceAllLiteralString(command.Results.Data.(string), "") - return - case "error": - log.Printf("[DEBUG] error caused by command: %s", command.Results.Cause) - err = a.getCommandResultErrorMessage(command) - return - } - err = fmt.Errorf("Unknown result type %s: %v", command.Results.ResultType, command.Results.Data) - return -} - -func (a CommandsAPI) getCommandResultErrorMessage(command Command) error { - summary := tagRE.ReplaceAllLiteralString(command.Results.Summary, "") - summary = html.UnescapeString(summary) - - exceptionMatches := exceptionRE.FindStringSubmatch(summary) - if len(exceptionMatches) == 2 { - summary = strings.ReplaceAll(exceptionMatches[1], "; nested exception is:", "") - summary = strings.TrimRight(summary, " ") - return errors.New(summary) - } - - executionErrorMatches := executionErrorRE.FindStringSubmatch(command.Results.Cause) - if len(executionErrorMatches) == 4 { - return errors.New(strings.Join(executionErrorMatches[1:], "\n")) - } - - errorMessageMatches := errorMessageRE.FindStringSubmatch(command.Results.Cause) - if len(errorMessageMatches) == 2 { - return errors.New(errorMessageMatches[1]) + log.Printf("[ERROR] Command has no results: %#v", command) + return common.CommandResults{ + ResultType: "error", + Summary: "Command has no results", + } } - - return errors.New(summary) + return *command.Results } type genericCommandRequest 
struct { diff --git a/compute/commands_test.go b/compute/commands_test.go index b3ed2a6619..9ba814562b 100644 --- a/compute/commands_test.go +++ b/compute/commands_test.go @@ -83,7 +83,7 @@ func commonFixtureWithStatusResponse(response Command) []qa.HTTPFixture { func TestCommandWithExecutionError(t *testing.T) { client, server, err := qa.HttpFixtureClient(t, commonFixtureWithStatusResponse(Command{ Status: "Finished", - Results: &CommandResults{ + Results: &common.CommandResults{ ResultType: "error", Cause: ` --- @@ -99,16 +99,15 @@ StatusDescription=BadRequest ctx := context.Background() commands := NewCommandsAPI(ctx, client) - _, err = commands.Execute("abc", "python", `print("done")`) - assert.Equal(t, `An error occurred -StatusCode=400 -StatusDescription=BadRequest`, err.Error()) + result := commands.Execute("abc", "python", `print("done")`) + assert.Equal(t, true, result.Failed()) + assert.Equal(t, "An error occurred\nStatusCode=400\nStatusDescription=BadRequest", result.Error()) } func TestCommandWithEmptyErrorMessageUsesSummary(t *testing.T) { client, server, err := qa.HttpFixtureClient(t, commonFixtureWithStatusResponse(Command{ Status: "Finished", - Results: &CommandResults{ + Results: &common.CommandResults{ ResultType: "error", Cause: ` --- @@ -125,14 +124,15 @@ ErrorMessage= ctx := context.Background() commands := NewCommandsAPI(ctx, client) - _, err = commands.Execute("abc", "python", `print("done")`) - assert.Equal(t, "Proper error", err.Error()) + result := commands.Execute("abc", "python", `print("done")`) + assert.Equal(t, true, result.Failed()) + assert.Equal(t, "Proper error", result.Error()) } func TestCommandWithErrorMessage(t *testing.T) { client, server, err := qa.HttpFixtureClient(t, commonFixtureWithStatusResponse(Command{ Status: "Finished", - Results: &CommandResults{ + Results: &common.CommandResults{ ResultType: "error", Cause: ` --- @@ -147,14 +147,15 @@ ErrorMessage=An error occurred ctx := context.Background() commands := 
NewCommandsAPI(ctx, client) - _, err = commands.Execute("abc", "python", `print("done")`) - assert.Equal(t, "An error occurred", err.Error()) + result := commands.Execute("abc", "python", `print("done")`) + assert.Equal(t, true, result.Failed()) + assert.Equal(t, "An error occurred", result.Error()) } func TestCommandWithExceptionMessage(t *testing.T) { client, server, err := qa.HttpFixtureClient(t, commonFixtureWithStatusResponse(Command{ Status: "Finished", - Results: &CommandResults{ + Results: &common.CommandResults{ ResultType: "error", Summary: "Exception: An error occurred", }, @@ -164,14 +165,15 @@ func TestCommandWithExceptionMessage(t *testing.T) { ctx := context.Background() commands := NewCommandsAPI(ctx, client) - _, err = commands.Execute("abc", "python", `print("done")`) - assert.Equal(t, "An error occurred", err.Error()) + result := commands.Execute("abc", "python", `print("done")`) + assert.Equal(t, true, result.Failed()) + assert.Equal(t, "An error occurred", result.Error()) } func TestSomeCommands(t *testing.T) { client, server, err := qa.HttpFixtureClient(t, commonFixtureWithStatusResponse(Command{ Status: "Finished", - Results: &CommandResults{ + Results: &common.CommandResults{ ResultType: "text", Data: "done", }, @@ -181,9 +183,334 @@ func TestSomeCommands(t *testing.T) { ctx := context.Background() commands := NewCommandsAPI(ctx, client) - result, err := commands.Execute("abc", "python", `print("done")`) - require.NoError(t, err) - assert.Equal(t, "done", result) + result := commands.Execute("abc", "python", `print("done")`) + assert.Equal(t, false, result.Failed()) + assert.Equal(t, "done", result.Text()) +} + +func TestCommandsAPIExecute_FailGettingCluster(t *testing.T) { + qa.HTTPFixturesApply(t, []qa.HTTPFixture{ + { + Method: "GET", + Resource: "/api/2.0/clusters/get?cluster_id=abc", + Status: 417, + Response: common.APIError{ + Message: "Does not compute", + }, + }, + }, func(ctx context.Context, client *common.DatabricksClient) { + 
commands := NewCommandsAPI(ctx, client) + cr := commands.Execute("abc", "cobol", "Hello?") + assert.EqualError(t, cr.Err(), "Does not compute") + }) +} + +func TestCommandsAPIExecute_StoppedCluster(t *testing.T) { + qa.HTTPFixturesApply(t, []qa.HTTPFixture{ + { + Method: "GET", + Resource: "/api/2.0/clusters/get?cluster_id=abc", + Response: ClusterInfo{ + State: "TERMINATED", + }, + }, + }, func(ctx context.Context, client *common.DatabricksClient) { + commands := NewCommandsAPI(ctx, client) + cr := commands.Execute("abc", "cobol", "Hello?") + assert.EqualError(t, cr.Err(), "Cluster abc has to be running or resizing, but is TERMINATED") + }) +} + +func TestCommandsAPIExecute_FailToCreateContext(t *testing.T) { + qa.HTTPFixturesApply(t, []qa.HTTPFixture{ + { + Method: "GET", + Resource: "/api/2.0/clusters/get?cluster_id=abc", + Response: ClusterInfo{ + State: "RUNNING", + }, + }, + { + Method: "POST", + Resource: "/api/1.2/contexts/create", + Status: 417, + Response: common.APIError{ + Message: "Does not compute", + }, + }, + }, func(ctx context.Context, client *common.DatabricksClient) { + commands := NewCommandsAPI(ctx, client) + cr := commands.Execute("abc", "cobol", "Hello?") + assert.EqualError(t, cr.Err(), "Does not compute") + }) +} + +func TestCommandsAPIExecute_FailToWaitForContext(t *testing.T) { + qa.HTTPFixturesApply(t, []qa.HTTPFixture{ + { + Method: "GET", + Resource: "/api/2.0/clusters/get?cluster_id=abc", + Response: ClusterInfo{ + State: "RUNNING", + }, + }, + { + Method: "POST", + Resource: "/api/1.2/contexts/create", + Response: Command{ + ID: "abc", + }, + }, + { + Method: "GET", + Resource: "/api/1.2/contexts/status?clusterId=abc&contextId=abc", + Status: 417, + Response: common.APIError{ + Message: "Does not compute", + }, + }, + }, func(ctx context.Context, client *common.DatabricksClient) { + commands := NewCommandsAPI(ctx, client) + cr := commands.Execute("abc", "cobol", "Hello?") + assert.EqualError(t, cr.Err(), "Does not compute") + }) +} 
+ +func TestCommandsAPIExecute_FailToCreateCommand(t *testing.T) { + qa.HTTPFixturesApply(t, []qa.HTTPFixture{ + { + Method: "GET", + Resource: "/api/2.0/clusters/get?cluster_id=abc", + Response: ClusterInfo{ + State: "RUNNING", + }, + }, + { + Method: "POST", + Resource: "/api/1.2/contexts/create", + Response: Command{ + ID: "abc", + }, + }, + { + Method: "GET", + Resource: "/api/1.2/contexts/status?clusterId=abc&contextId=abc", + Response: Command{ + Status: "Running", + }, + }, + { + Method: "POST", + Resource: "/api/1.2/commands/execute", + Status: 417, + Response: common.APIError{ + Message: "Does not compute", + }, + }, + }, func(ctx context.Context, client *common.DatabricksClient) { + commands := NewCommandsAPI(ctx, client) + cr := commands.Execute("abc", "cobol", "Hello?") + assert.EqualError(t, cr.Err(), "Does not compute") + }) +} + +func TestCommandsAPIExecute_FailToWaitForCommand(t *testing.T) { + qa.HTTPFixturesApply(t, []qa.HTTPFixture{ + { + Method: "GET", + Resource: "/api/2.0/clusters/get?cluster_id=abc", + Response: ClusterInfo{ + State: "RUNNING", + }, + }, + { + Method: "POST", + Resource: "/api/1.2/contexts/create", + Response: Command{ + ID: "abc", + }, + }, + { + Method: "GET", + Resource: "/api/1.2/contexts/status?clusterId=abc&contextId=abc", + Response: Command{ + Status: "Running", + }, + }, + { + Method: "POST", + Resource: "/api/1.2/commands/execute", + Response: Command{ + ID: "abc", + }, + }, + { + Method: "GET", + Resource: "/api/1.2/commands/status?clusterId=abc&commandId=abc&contextId=abc", + Status: 417, + Response: common.APIError{ + Message: "Does not compute", + }, + }, + }, func(ctx context.Context, client *common.DatabricksClient) { + commands := NewCommandsAPI(ctx, client) + cr := commands.Execute("abc", "cobol", "Hello?") + assert.EqualError(t, cr.Err(), "Does not compute") + }) +} + +func TestCommandsAPIExecute_FailToGetCommand(t *testing.T) { + qa.HTTPFixturesApply(t, []qa.HTTPFixture{ + { + Method: "GET", + Resource: 
"/api/2.0/clusters/get?cluster_id=abc", + Response: ClusterInfo{ + State: "RUNNING", + }, + }, + { + Method: "POST", + Resource: "/api/1.2/contexts/create", + Response: Command{ + ID: "abc", + }, + }, + { + Method: "GET", + Resource: "/api/1.2/contexts/status?clusterId=abc&contextId=abc", + Response: Command{ + Status: "Running", + }, + }, + { + Method: "POST", + Resource: "/api/1.2/commands/execute", + Response: Command{ + ID: "abc", + }, + }, + { + Method: "GET", + Resource: "/api/1.2/commands/status?clusterId=abc&commandId=abc&contextId=abc", + Response: Command{ + Status: "Finished", + }, + }, + { + Method: "GET", + Resource: "/api/1.2/commands/status?clusterId=abc&commandId=abc&contextId=abc", + Status: 417, + Response: common.APIError{ + Message: "Does not compute", + }, + }, + }, func(ctx context.Context, client *common.DatabricksClient) { + commands := NewCommandsAPI(ctx, client) + cr := commands.Execute("abc", "cobol", "Hello?") + assert.EqualError(t, cr.Err(), "Does not compute") + }) +} + +func TestCommandsAPIExecute_FailToDeleteContext(t *testing.T) { + qa.HTTPFixturesApply(t, []qa.HTTPFixture{ + { + Method: "GET", + Resource: "/api/2.0/clusters/get?cluster_id=abc", + Response: ClusterInfo{ + State: "RUNNING", + }, + }, + { + Method: "POST", + Resource: "/api/1.2/contexts/create", + Response: Command{ + ID: "abc", + }, + }, + { + Method: "GET", + Resource: "/api/1.2/contexts/status?clusterId=abc&contextId=abc", + Response: Command{ + Status: "Running", + }, + }, + { + Method: "POST", + Resource: "/api/1.2/commands/execute", + Response: Command{ + ID: "abc", + }, + }, + { + Method: "GET", + ReuseRequest: true, + Resource: "/api/1.2/commands/status?clusterId=abc&commandId=abc&contextId=abc", + Response: Command{ + Status: "Finished", + }, + }, + { + Method: "POST", + Resource: "/api/1.2/contexts/destroy", + Status: 417, + Response: common.APIError{ + Message: "Does not compute", + }, + }, + }, func(ctx context.Context, client *common.DatabricksClient) { + 
commands := NewCommandsAPI(ctx, client) + cr := commands.Execute("abc", "cobol", "Hello?") + assert.EqualError(t, cr.Err(), "Does not compute") + }) +} + +func TestCommandsAPIExecute_NoCommandResults(t *testing.T) { + qa.HTTPFixturesApply(t, []qa.HTTPFixture{ + { + Method: "GET", + Resource: "/api/2.0/clusters/get?cluster_id=abc", + Response: ClusterInfo{ + State: "RUNNING", + }, + }, + { + Method: "POST", + Resource: "/api/1.2/contexts/create", + Response: Command{ + ID: "abc", + }, + }, + { + Method: "GET", + Resource: "/api/1.2/contexts/status?clusterId=abc&contextId=abc", + Response: Command{ + Status: "Running", + }, + }, + { + Method: "POST", + Resource: "/api/1.2/commands/execute", + Response: Command{ + ID: "abc", + }, + }, + { + Method: "GET", + ReuseRequest: true, + Resource: "/api/1.2/commands/status?clusterId=abc&commandId=abc&contextId=abc", + Response: Command{ + Status: "Finished", + }, + }, + { + Method: "POST", + Resource: "/api/1.2/contexts/destroy", + }, + }, func(ctx context.Context, client *common.DatabricksClient) { + commands := NewCommandsAPI(ctx, client) + cr := commands.Execute("abc", "cobol", "Hello?") + assert.EqualError(t, cr.Err(), "Command has no results") + }) } func TestAccContext(t *testing.T) { @@ -197,32 +524,26 @@ func TestAccContext(t *testing.T) { ctx := context.Background() c := NewCommandsAPI(ctx, client) - result, err := c.Execute(clusterID, "python", `print('hello world')`) - require.NoError(t, err) - assert.Equal(t, "hello world", result) + result := c.Execute(clusterID, "python", `print('hello world')`) + assert.Equal(t, "hello world", result.Text()) // exceptions are regexed away for readability - result, err = c.Execute(clusterID, "python", `raise Exception("Not Found")`) - qa.AssertErrorStartsWith(t, err, "Not Found") - assert.Equal(t, "", result) + result = c.Execute(clusterID, "python", `raise Exception("Not Found")`) + qa.AssertErrorStartsWith(t, result.Err(), "Not Found") // but errors are not - result, err = 
c.Execute(clusterID, "python", `raise KeyError("foo")`) - qa.AssertErrorStartsWith(t, err, "KeyError: 'foo'") - assert.Equal(t, "", result) + result = c.Execute(clusterID, "python", `raise KeyError("foo")`) + qa.AssertErrorStartsWith(t, result.Err(), "KeyError: 'foo'") // so it is more clear to read and debug - result, err = c.Execute(clusterID, "python", `return 'hello world'`) - qa.AssertErrorStartsWith(t, err, "SyntaxError: 'return' outside function") - assert.Equal(t, "", result) - - result, err = c.Execute(clusterID, "python", `"Hello World!"`) - assert.NoError(t, err) - assert.Equal(t, "'Hello World!'", result) - - result, err = c.Execute(clusterID, "python", ` - print("Hello World!") - dbutils.notebook.exit("success")`) - assert.NoError(t, err) - assert.Equal(t, "success", result) + result = c.Execute(clusterID, "python", `return 'hello world'`) + qa.AssertErrorStartsWith(t, result.Err(), "SyntaxError: 'return' outside function") + + result = c.Execute(clusterID, "python", `"Hello World!"`) + assert.Equal(t, "'Hello World!'", result.Text()) + + result = c.Execute(clusterID, "python", ` + print("Hello World!") + dbutils.notebook.exit("success")`) + assert.Equal(t, "success", result.Text()) } diff --git a/compute/model.go b/compute/model.go index 0b47f2b2db..e69de32a3b 100644 --- a/compute/model.go +++ b/compute/model.go @@ -3,6 +3,8 @@ package compute import ( "fmt" "sort" + + "github.com/databrickslabs/terraform-provider-databricks/common" ) // AutoScale is a struct the describes auto scaling for clusters @@ -366,22 +368,11 @@ type ClusterPolicyCreate struct { Definition string `json:"definition"` } -// CommandResults is the out put when the command finishes in API 1.2 -type CommandResults struct { - ResultType string `json:"resultType,omitempty"` - Summary string `json:"summary,omitempty"` - Cause string `json:"cause,omitempty"` - Data interface{} `json:"data,omitempty"` - Schema interface{} `json:"schema,omitempty"` - Truncated bool 
`json:"truncated,omitempty"` - IsJSONSchema bool `json:"isJsonSchema,omitempty"` -} - // Command is the struct that contains what the 1.2 api returns for the commands api type Command struct { - ID string `json:"id,omitempty"` - Status string `json:"status,omitempty"` - Results *CommandResults `json:"results,omitempty"` + ID string `json:"id,omitempty"` + Status string `json:"status,omitempty"` + Results *common.CommandResults `json:"results,omitempty"` } // InstancePoolAwsAttributes contains aws attributes for AWS Databricks deployments for instance pools diff --git a/docs/index.md b/docs/index.md index 69d0c3209b..30523d402e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -32,6 +32,7 @@ Security * Manage data access with [databricks_instance_profile](resources/instance_profile.md), which can be assigned through [databricks_group_instance_profile](resources/group_instance_profile.md) and [databricks_user_instance_profile](resources/user_instance_profile.md) * Control which networks can access workspace with [databricks_ip_access_list](resources/ip_access_list.md) * Generically manage [databricks_permissions](resources/permissions.md) +* Manage data object access control lists with [databricks_sql_permissions](resources/sql_permissions.md) * Keep sensitive elements like passwords in [databricks_secret](resources/secret.md), grouped into [databricks_secret_scope](resources/secret_scope.md) and controlled by [databricks_secret_acl](resources/secret_acl.md) diff --git a/docs/resources/cluster.md b/docs/resources/cluster.md index 72ad125cef..3206161c46 100644 --- a/docs/resources/cluster.md +++ b/docs/resources/cluster.md @@ -65,7 +65,7 @@ resource "databricks_cluster" "shared_autoscaling" { min_workers = 1 max_workers = 50 } - spark_conf { + spark_conf = { "spark.databricks.io.cache.enabled": true, "spark.databricks.io.cache.maxDiskUsage": "50g", "spark.databricks.io.cache.maxMetaDataCache": "1g" diff --git a/docs/resources/permissions.md 
b/docs/resources/permissions.md index d61f343acf..12dcb42b4f 100644 --- a/docs/resources/permissions.md +++ b/docs/resources/permissions.md @@ -452,7 +452,7 @@ One can control access to [databricks_secret](secret.md) through `initial_manage ## Tables, Views and Databases -General Permissions API does not apply to access control for tables and permissions have to be managed separately. Though, terraform integration is coming in the future versions. +General Permissions API does not apply to access control for tables and they have to be managed separately using the [databricks_sql_permissions](sql_permissions.md) resource. ## Argument Reference diff --git a/docs/resources/sql_permissions.md b/docs/resources/sql_permissions.md new file mode 100644 index 0000000000..2310a5cf9d --- /dev/null +++ b/docs/resources/sql_permissions.md @@ -0,0 +1,93 @@ +--- +subcategory: "Security" +--- +# databricks_sql_permissions Resource + +-> **Note** This resource has an evolving API, which may change in the upcoming versions. + +This resource manages data object access control lists in Databricks workspaces for things like tables, views, databases, and [more](https://docs.databricks.com/security/access-control/table-acls/object-privileges.html). In order to enable Table Access control, you have to login to the workspace as administrator, go to `Admin Console`, pick `Access Control` tab, click on `Enable` button in `Table Access Control` section, and click `Confirm`. The security guarantees of table access control **will only be effective if cluster access control is also turned on**. Please make sure that no users can create clusters in your workspace and all [databricks_cluster](cluster.md) have approximately the following configuration: + +```hcl +resource "databricks_cluster" "cluster_with_table_access_control" { + // ... 
+ + spark_conf = { + "spark.databricks.acl.dfAclsEnabled": "true", + "spark.databricks.repl.allowedLanguages": "sql,python,r", + "spark.databricks.cluster.profile": "serverless" + } + + custom_tags = { + "ResourceClass" = "Serverless" + } +} +``` + +## Example Usage + +The following resource definition will enforce access control on a table by executing the following SQL queries on a special auto-terminating cluster it would create for this operation: + +* ```SHOW GRANT ON TABLE `default`.`foo` ``` +* ```REVOKE ALL PRIVILEGES ON TABLE `default`.`foo` FROM ... every group and user that has access to it ...``` +* ```GRANT MODIFY, SELECT ON TABLE `default`.`foo` TO `serge@example.com` ``` +* ```GRANT SELECT ON TABLE `default`.`foo` TO `special group` ``` + +```hcl +resource "databricks_sql_permissions" "foo_table" { + table = "foo" + + privilege_assignments { + principal = "serge@example.com" + privileges = ["SELECT", "MODIFY"] + } + + privilege_assignments { + principal = "special group" + privileges = ["SELECT"] + } +} +``` + +## Argument Reference + +The following arguments are available to specify the data object you need to enforce access controls on. You must specify only one of those arguments (except for `table` and `view`), otherwise resource creation will fail. + +* `database` - Name of the database. Has default value of `default`. +* `table` - Name of the table. Can be combined with `database`. +* `view` - Name of the view. Can be combined with `database`. +* `catalog` - (Boolean) If this access control for the entire catalog. Defaults to `false`. +* `any_file` - (Boolean) If this access control for reading any file. Defaults to `false`. +* `anonymous_function` - (Boolean) If this access control for using anonymous function. Defaults to `false`. 
+ +### `privilege_assignments` blocks + +You must specify one or many `privilege_assignments` configuration blocks to declare `privileges` to a `principal`, which corresponds to `display_name` of [databricks_group](group.md#display_name) or [databricks_user](user.md#display_name). Terraform would ensure that only those principals and privileges defined in the resource are applied for the data object and would remove anything else. It would not remove any transitive privileges. `DENY` statements are intentionally not supported. Every `privilege_assignments` has the following required arguments: + +* `principal` - `display_name` of [databricks_group](group.md#display_name) or [databricks_user](user.md#display_name). +* `privileges` - set of available privilege names in upper case. + +[Available](https://docs.databricks.com/security/access-control/table-acls/object-privileges.html) privilege names are: + +* `SELECT` - gives read access to an object. +* `CREATE` - gives the ability to create an object (for example, a table in a database). +* `MODIFY` - gives the ability to add, delete, and modify data to or from an object. +* `USAGE` - do not give any abilities, but is an additional requirement to perform any action on a database object. +* `READ_METADATA` - gives the ability to view an object and its metadata. +* `CREATE_NAMED_FUNCTION` - gives the ability to create a named UDF in an existing catalog or database. +* `MODIFY_CLASSPATH` - gives the ability to add files to the Spark class path. +* `ALL PRIVILEGES` - gives all privileges (is translated into all the above privileges). + +## Import + +The resource can be imported using a synthetic identifier. Examples of valid synthetic identifiers are: + +* `table/default.foo` - table `foo` in a `default` database. Database is always mandatory. +* `view/bar.foo` - view `foo` in `bar` database. +* `database/bar` - `bar` database. +* `catalog/` - entire catalog. `/` suffix is mandatory. 
+* `any file/` - direct access to any file. `/` suffix is mandatory. +* `anonymous function/` - anonymous function. `/` suffix is mandatory. + +```bash +$ terraform import databricks_sql_permissions.foo // +``` diff --git a/exporter/exporter_test.go b/exporter/exporter_test.go index 58727d34a4..b9fdb11215 100644 --- a/exporter/exporter_test.go +++ b/exporter/exporter_test.go @@ -96,7 +96,7 @@ func TestImportingMounts(t *testing.T) { Resource: "/api/1.2/commands/status?clusterId=mount&commandId=run&contextId=context", Response: compute.Command{ Status: "Finished", - Results: &compute.CommandResults{ + Results: &common.CommandResults{ ResultType: "text", Data: `{"foo": "s3a://foo", "bar": "abfss://bar@baz.com/thing", "third": "adls://foo.bar.com/path"} and some chatty messages`, diff --git a/exporter/util.go b/exporter/util.go index ae920b477c..76e0181b8c 100644 --- a/exporter/util.go +++ b/exporter/util.go @@ -236,11 +236,12 @@ println(mapper.writeValueAsString(readableMounts))` func (ic *importContext) getMountsThroughCluster( commandAPI common.CommandExecutor, clusterID string) (mm map[string]string, err error) { // Scala has actually working timeout handling, compared to Python - j, err := commandAPI.Execute(clusterID, "scala", getReadableMountsCommand) - if err != nil { + result := commandAPI.Execute(clusterID, "scala", getReadableMountsCommand) + if result.Failed() { + err = result.Err() return } - lines := strings.Split(j, "\n") + lines := strings.Split(result.Text(), "\n") err = json.Unmarshal([]byte(lines[0]), &mm) return } diff --git a/provider/provider.go b/provider/provider.go index 4888189c81..69e808f2a5 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -37,11 +37,12 @@ func DatabricksProvider() *schema.Provider { "databricks_zones": compute.DataSourceClusterZones(), }, ResourcesMap: map[string]*schema.Resource{ - "databricks_secret": access.ResourceSecret(), - "databricks_secret_scope": access.ResourceSecretScope(), - 
"databricks_secret_acl": access.ResourceSecretACL(), - "databricks_permissions": access.ResourcePermissions(), - "databricks_ip_access_list": access.ResourceIPAccessList(), + "databricks_secret": access.ResourceSecret(), + "databricks_secret_scope": access.ResourceSecretScope(), + "databricks_secret_acl": access.ResourceSecretACL(), + "databricks_permissions": access.ResourcePermissions(), + "databricks_sql_permissions": access.ResourceSqlPermissions(), + "databricks_ip_access_list": access.ResourceIPAccessList(), "databricks_cluster": compute.ResourceCluster(), "databricks_cluster_policy": compute.ResourceClusterPolicy(), diff --git a/qa/testing.go b/qa/testing.go index b624311bd8..6c9fb6b310 100644 --- a/qa/testing.go +++ b/qa/testing.go @@ -240,7 +240,7 @@ func (f ResourceFixture) ExpectError(t *testing.T, msg string) { } // ResourceCornerCases checks for corner cases of error handling. Optional field name used to create error -func ResourceCornerCases(t *testing.T, resource *schema.Resource) { +func ResourceCornerCases(t *testing.T, resource *schema.Resource, id ...string) { teapot := "I'm a teapot" m := map[string]func(ctx context.Context, d *schema.ResourceData, m interface{}) diag.Diagnostics{ "create": resource.CreateContext, @@ -248,6 +248,10 @@ func ResourceCornerCases(t *testing.T, resource *schema.Resource) { "update": resource.UpdateContext, "delete": resource.DeleteContext, } + fakeID := "x" + if len(id) > 0 { + fakeID = id[0] + } HTTPFixturesApply(t, []HTTPFixture{ { MatchAny: true, @@ -261,7 +265,7 @@ func ResourceCornerCases(t *testing.T, resource *schema.Resource) { }, }, func(ctx context.Context, client *common.DatabricksClient) { validData := resource.TestResourceData() - validData.SetId("x") + validData.SetId(fakeID) for n, v := range m { if v == nil { continue diff --git a/qa/testing_test.go b/qa/testing_test.go index 3578072d56..9fb53a2d99 100644 --- a/qa/testing_test.go +++ b/qa/testing_test.go @@ -163,8 +163,11 @@ func 
TestResourceFixture_ID(t *testing.T) { func TestResourceFixture_Apply(t *testing.T) { d, err := ResourceFixture{ - CommandMock: func(commandStr string) (string, error) { - return "yes", nil + CommandMock: func(commandStr string) common.CommandResults { + return common.CommandResults{ + ResultType: "text", + Data: "yes", + } }, Azure: true, Resource: noopResource, @@ -179,8 +182,11 @@ func TestResourceFixture_Apply(t *testing.T) { func TestResourceFixture_ApplyDelete(t *testing.T) { d, err := ResourceFixture{ - CommandMock: func(commandStr string) (string, error) { - return "yes", nil + CommandMock: func(commandStr string) common.CommandResults { + return common.CommandResults{ + ResultType: "text", + Data: "yes", + } }, Azure: true, Resource: noopContextResource, @@ -214,8 +220,11 @@ func TestResourceFixture_InstanceState(t *testing.T) { func TestResourceFixture_Apply_Fail(t *testing.T) { _, err := ResourceFixture{ - CommandMock: func(commandStr string) (string, error) { - return "yes", nil + CommandMock: func(commandStr string) common.CommandResults { + return common.CommandResults{ + ResultType: "text", + Data: "yes", + } }, Resource: noopResource, Create: true, diff --git a/storage/adls_gen1_mount_test.go b/storage/adls_gen1_mount_test.go index f18278e136..13147b4485 100644 --- a/storage/adls_gen1_mount_test.go +++ b/storage/adls_gen1_mount_test.go @@ -4,6 +4,7 @@ import ( "strings" "testing" + "github.com/databrickslabs/terraform-provider-databricks/common" "github.com/databrickslabs/terraform-provider-databricks/compute" "github.com/databrickslabs/terraform-provider-databricks/internal" @@ -44,7 +45,7 @@ func TestResourceAdlsGen1Mount_Create(t *testing.T) { }, }, Resource: ResourceAzureAdlsGen1Mount(), - CommandMock: func(commandStr string) (string, error) { + CommandMock: func(commandStr string) common.CommandResults { trunc := internal.TrimLeadingWhitespace(commandStr) t.Logf("Received command:\n%s", trunc) if strings.HasPrefix(trunc, "def safe_mount") { @@ 
-52,7 +53,10 @@ func TestResourceAdlsGen1Mount_Create(t *testing.T) { assert.Contains(t, trunc, `"fs.adl.oauth2.credential":dbutils.secrets.get("c", "d")`) } assert.Contains(t, trunc, "/mnt/this_mount") - return testS3BucketPath, nil + return common.CommandResults{ + ResultType: "text", + Data: testS3BucketPath, + } }, State: map[string]interface{}{ "cluster_id": "this_cluster", diff --git a/storage/adls_gen2_mount_test.go b/storage/adls_gen2_mount_test.go index 5fe6f97ed2..5ecfbbebed 100644 --- a/storage/adls_gen2_mount_test.go +++ b/storage/adls_gen2_mount_test.go @@ -4,6 +4,7 @@ import ( "strings" "testing" + "github.com/databrickslabs/terraform-provider-databricks/common" "github.com/databrickslabs/terraform-provider-databricks/compute" "github.com/databrickslabs/terraform-provider-databricks/internal" @@ -46,7 +47,7 @@ func TestResourceAdlsGen2Mount_Create(t *testing.T) { }, }, Resource: ResourceAzureAdlsGen2Mount(), - CommandMock: func(commandStr string) (string, error) { + CommandMock: func(commandStr string) common.CommandResults { trunc := internal.TrimLeadingWhitespace(commandStr) t.Logf("Received command:\n%s", trunc) if strings.HasPrefix(trunc, "def safe_mount") { @@ -54,7 +55,10 @@ func TestResourceAdlsGen2Mount_Create(t *testing.T) { assert.Contains(t, trunc, `"fs.azure.account.oauth2.client.secret":dbutils.secrets.get("c", "d")`) } assert.Contains(t, trunc, "/mnt/this_mount") - return testS3BucketPath, nil + return common.CommandResults{ + ResultType: "text", + Data: "abfss://e@test-adls-gen2.dfs.core.windows.net", + } }, State: map[string]interface{}{ "cluster_id": "this_cluster", @@ -71,5 +75,5 @@ func TestResourceAdlsGen2Mount_Create(t *testing.T) { }.Apply(t) require.NoError(t, err, err) assert.Equal(t, "this_mount", d.Id()) - assert.Equal(t, testS3BucketPath, d.Get("source")) + assert.Equal(t, "abfss://e@test-adls-gen2.dfs.core.windows.net", d.Get("source")) } diff --git a/storage/aws_s3_mount_test.go b/storage/aws_s3_mount_test.go index 
00289052df..68a3bfe53e 100644 --- a/storage/aws_s3_mount_test.go +++ b/storage/aws_s3_mount_test.go @@ -2,7 +2,6 @@ package storage import ( "context" - "errors" "strings" "testing" @@ -38,7 +37,7 @@ func TestResourceAwsS3MountCreate(t *testing.T) { }, }, Resource: ResourceAWSS3Mount(), - CommandMock: func(commandStr string) (string, error) { + CommandMock: func(commandStr string) common.CommandResults { trunc := internal.TrimLeadingWhitespace(commandStr) t.Logf("Received command:\n%s", trunc) if strings.HasPrefix(trunc, "def safe_mount") { @@ -46,7 +45,10 @@ func TestResourceAwsS3MountCreate(t *testing.T) { assert.Contains(t, trunc, `{}`) // empty brackets for empty config } assert.Contains(t, trunc, "/mnt/this_mount") - return testS3BucketPath, nil + return common.CommandResults{ + ResultType: "text", + Data: testS3BucketPath, + } }, State: map[string]interface{}{ "cluster_id": "this_cluster", @@ -55,7 +57,7 @@ func TestResourceAwsS3MountCreate(t *testing.T) { }, Create: true, }.Apply(t) - require.NoError(t, err, err) // TODO: global search-replace for NoError + require.NoError(t, err, err) assert.Equal(t, "this_mount", d.Id()) assert.Equal(t, testS3BucketPath, d.Get("source")) } @@ -85,37 +87,6 @@ func TestResourceAwsS3MountCreate_invalid_arn(t *testing.T) { require.EqualError(t, err, "arn: invalid prefix") } -func TestResourceAwsS3MountCreate_Error(t *testing.T) { - d, err := qa.ResourceFixture{ - Fixtures: []qa.HTTPFixture{ - { - Method: "GET", - ReuseRequest: true, - Resource: "/api/2.0/clusters/get?cluster_id=this_cluster", - Response: compute.ClusterInfo{ - State: compute.ClusterStateRunning, - AwsAttributes: &compute.AwsAttributes{ - InstanceProfileArn: "abc", - }, - }, - }, - }, - Resource: ResourceAWSS3Mount(), - CommandMock: func(commandStr string) (string, error) { - return "", errors.New("Some error") - }, - State: map[string]interface{}{ - "cluster_id": "this_cluster", - "mount_name": "this_mount", - "s3_bucket_name": testS3BucketName, - }, - Create: 
true, - }.Apply(t) - require.EqualError(t, err, "Some error") - assert.Equal(t, "this_mount", d.Id()) - assert.Equal(t, "", d.Get("source")) -} - func TestResourceAwsS3MountRead(t *testing.T) { d, err := qa.ResourceFixture{ Fixtures: []qa.HTTPFixture{ @@ -132,12 +103,15 @@ func TestResourceAwsS3MountRead(t *testing.T) { }, }, Resource: ResourceAWSS3Mount(), - CommandMock: func(commandStr string) (string, error) { + CommandMock: func(commandStr string) common.CommandResults { trunc := internal.TrimLeadingWhitespace(commandStr) t.Logf("Received command:\n%s", trunc) assert.Contains(t, trunc, "dbutils.fs.mounts()") assert.Contains(t, trunc, `mount.mountPoint == "/mnt/this_mount"`) - return testS3BucketPath, nil + return common.CommandResults{ + ResultType: "text", + Data: testS3BucketPath, + } }, State: map[string]interface{}{ "cluster_id": "this_cluster", @@ -168,10 +142,13 @@ func TestResourceAwsS3MountRead_NotFound(t *testing.T) { }, }, Resource: ResourceAWSS3Mount(), - CommandMock: func(commandStr string) (string, error) { + CommandMock: func(commandStr string) common.CommandResults { trunc := internal.TrimLeadingWhitespace(commandStr) t.Logf("Received command:\n%s", trunc) - return "", errors.New("Mount not found") + return common.CommandResults{ + ResultType: "error", + Summary: "Mount not found", + } }, State: map[string]interface{}{ "cluster_id": "this_cluster", @@ -200,10 +177,13 @@ func TestResourceAwsS3MountRead_Error(t *testing.T) { }, }, Resource: ResourceAWSS3Mount(), - CommandMock: func(commandStr string) (string, error) { + CommandMock: func(commandStr string) common.CommandResults { trunc := internal.TrimLeadingWhitespace(commandStr) t.Logf("Received command:\n%s", trunc) - return "", errors.New("Some error") + return common.CommandResults{ + ResultType: "error", + Summary: "Some error", + } }, State: map[string]interface{}{ "cluster_id": "this_cluster", @@ -234,12 +214,15 @@ func TestResourceAwsS3MountDelete(t *testing.T) { }, }, Resource: 
ResourceAWSS3Mount(), - CommandMock: func(commandStr string) (string, error) { + CommandMock: func(commandStr string) common.CommandResults { trunc := internal.TrimLeadingWhitespace(commandStr) t.Logf("Received command:\n%s", trunc) assert.Contains(t, trunc, "/mnt/this_mount") assert.Contains(t, trunc, "dbutils.fs.unmount(mount_point)") - return "", nil + return common.CommandResults{ + ResultType: "text", + Data: "", + } }, State: map[string]interface{}{ "cluster_id": "this_cluster", diff --git a/storage/azure_blob_mount_test.go b/storage/azure_blob_mount_test.go index 3d5ebd4b2f..a24b92339e 100644 --- a/storage/azure_blob_mount_test.go +++ b/storage/azure_blob_mount_test.go @@ -1,10 +1,10 @@ package storage import ( - "errors" "strings" "testing" + "github.com/databrickslabs/terraform-provider-databricks/common" "github.com/databrickslabs/terraform-provider-databricks/compute" "github.com/databrickslabs/terraform-provider-databricks/internal" @@ -32,7 +32,7 @@ func TestResourceAzureBlobMountCreate(t *testing.T) { }, }, Resource: ResourceAzureBlobMount(), - CommandMock: func(commandStr string) (string, error) { + CommandMock: func(commandStr string) common.CommandResults { trunc := internal.TrimLeadingWhitespace(commandStr) t.Logf("Received command:\n%s", trunc) @@ -41,7 +41,10 @@ func TestResourceAzureBlobMountCreate(t *testing.T) { assert.Contains(t, trunc, `"fs.azure.account.key.f.blob.core.windows.net":dbutils.secrets.get("h", "g")`) } assert.Contains(t, trunc, "/mnt/e") - return "wasbs://c@f.blob.core.windows.net/d", nil + return common.CommandResults{ + ResultType: "text", + Data: "wasbs://c@f.blob.core.windows.net/d", + } }, State: map[string]interface{}{ "auth_type": "ACCESS_KEY", @@ -72,8 +75,11 @@ func TestResourceAzureBlobMountCreate_Error(t *testing.T) { }, }, Resource: ResourceAzureBlobMount(), - CommandMock: func(commandStr string) (string, error) { - return "", errors.New("Some error") + CommandMock: func(commandStr string) common.CommandResults { + 
return common.CommandResults{ + ResultType: "error", + Summary: "Some error", + } }, State: map[string]interface{}{ "auth_type": "ACCESS_KEY", @@ -104,12 +110,15 @@ func TestResourceAzureBlobMountRead(t *testing.T) { }, }, Resource: ResourceAzureBlobMount(), - CommandMock: func(commandStr string) (string, error) { + CommandMock: func(commandStr string) common.CommandResults { trunc := internal.TrimLeadingWhitespace(commandStr) t.Logf("Received command:\n%s", trunc) assert.Contains(t, trunc, "dbutils.fs.mounts()") assert.Contains(t, trunc, `mount.mountPoint == "/mnt/e"`) - return "wasbs://c@f.blob.core.windows.net/d", nil + return common.CommandResults{ + ResultType: "text", + Data: "wasbs://c@f.blob.core.windows.net/d", + } }, State: map[string]interface{}{ "auth_type": "ACCESS_KEY", @@ -141,10 +150,13 @@ func TestResourceAzureBlobMountRead_NotFound(t *testing.T) { }, }, Resource: ResourceAzureBlobMount(), - CommandMock: func(commandStr string) (string, error) { + CommandMock: func(commandStr string) common.CommandResults { trunc := internal.TrimLeadingWhitespace(commandStr) t.Logf("Received command:\n%s", trunc) - return "", errors.New("Mount not found") + return common.CommandResults{ + ResultType: "error", + Summary: "Mount not found", + } }, State: map[string]interface{}{ "auth_type": "ACCESS_KEY", @@ -174,10 +186,13 @@ func TestResourceAzureBlobMountRead_Error(t *testing.T) { }, }, Resource: ResourceAzureBlobMount(), - CommandMock: func(commandStr string) (string, error) { + CommandMock: func(commandStr string) common.CommandResults { trunc := internal.TrimLeadingWhitespace(commandStr) t.Logf("Received command:\n%s", trunc) - return "", errors.New("Some error") + return common.CommandResults{ + ResultType: "error", + Summary: "Some error", + } }, State: map[string]interface{}{ "auth_type": "ACCESS_KEY", @@ -209,11 +224,14 @@ func TestResourceAzureBlobMountDelete(t *testing.T) { }, }, Resource: ResourceAzureBlobMount(), - CommandMock: func(commandStr string) 
(string, error) { + CommandMock: func(commandStr string) common.CommandResults { trunc := internal.TrimLeadingWhitespace(commandStr) t.Logf("Received command:\n%s", trunc) assert.Contains(t, trunc, "dbutils.fs.unmount(mount_point)") - return "", nil + return common.CommandResults{ + ResultType: "text", + Data: "", + } }, State: map[string]interface{}{ "auth_type": "ACCESS_KEY", diff --git a/storage/mounts.go b/storage/mounts.go index cb608a63c4..0dfbcf8e8b 100644 --- a/storage/mounts.go +++ b/storage/mounts.go @@ -31,18 +31,19 @@ type MountPoint struct { // Source returns mountpoint source func (mp MountPoint) Source() (string, error) { - return mp.exec.Execute(mp.clusterID, "python", fmt.Sprintf(` + result := mp.exec.Execute(mp.clusterID, "python", fmt.Sprintf(` dbutils.fs.refreshMounts() for mount in dbutils.fs.mounts(): if mount.mountPoint == "/mnt/%s": dbutils.notebook.exit(mount.source) raise Exception("Mount not found") `, mp.name)) + return result.Text(), result.Err() } // Delete removes mount from workspace func (mp MountPoint) Delete() error { - _, err := mp.exec.Execute(mp.clusterID, "python", fmt.Sprintf(` + result := mp.exec.Execute(mp.clusterID, "python", fmt.Sprintf(` mount_point = "/mnt/%s" dbutils.fs.unmount(mount_point) dbutils.fs.refreshMounts() @@ -51,7 +52,7 @@ func (mp MountPoint) Delete() error { raise Exception("Failed to unmount") dbutils.notebook.exit("success") `, mp.name)) - return err + return result.Err() } // Mount mounts object store on workspace @@ -81,8 +82,8 @@ func (mp MountPoint) Mount(mo Mount) (source string, err error) { mount_source = safe_mount("/mnt/%s", "%v", %s) dbutils.notebook.exit(mount_source) `, mp.name, mo.Source(), extraConfigs) - source, err = mp.exec.Execute(mp.clusterID, "python", command) - return + result := mp.exec.Execute(mp.clusterID, "python", command) + return result.Text(), result.Err() } func commonMountResource(tpl Mount, s map[string]*schema.Schema) *schema.Resource { diff --git 
a/storage/mounts_test.go b/storage/mounts_test.go index 0aa0d0d9d9..3e8b276944 100644 --- a/storage/mounts_test.go +++ b/storage/mounts_test.go @@ -128,10 +128,13 @@ func testMountFuncHelper(t *testing.T, mountFunc func(mp MountPoint, mount Mount var called bool - c.WithCommandMock(func(commandStr string) (s string, e error) { + c.WithCommandMock(func(commandStr string) common.CommandResults { called = true assert.Equal(t, internal.TrimLeadingWhitespace(expectedCommand), internal.TrimLeadingWhitespace(commandStr)) - return expectedCommandResp, nil + return common.CommandResults{ + ResultType: "text", + Data: expectedCommandResp, + } }) ctx := context.Background()