Skip to content

Commit

Permalink
feat: Introduce shared function to suspend task roots (#2650)
Browse files Browse the repository at this point in the history
Refactor previously planned function for task suspension, which is also
necessary to push the grant ownership resource forward.

## Test Plan
<!-- detail ways in which this PR has been tested or needs to be tested
-->
* [x] acceptance tests - tested on existing tests (had to adjust the new
suspension code to mimic the previous behavior; otherwise I got a
Snowflake error indicating that the root tasks were not suspended and
the logic was wrong).
* [x] integration test to test new SDK function

## Reference
[CREATE
TASK](https://docs.snowflake.com/en/sql-reference/sql/create-task)
  • Loading branch information
sfc-gh-jcieslak authored Apr 3, 2024
1 parent b542b69 commit d684b5d
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 82 deletions.
115 changes: 33 additions & 82 deletions pkg/resources/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,24 +348,16 @@ func CreateTask(d *schema.ResourceData, meta interface{}) error {
precedingTasks := make([]sdk.SchemaObjectIdentifier, 0)
for _, dep := range after {
precedingTaskId := sdk.NewSchemaObjectIdentifier(databaseName, schemaName, dep)
rootTasks, err := sdk.GetRootTasks(client.Tasks, ctx, precedingTaskId)
tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, precedingTaskId, taskId)
defer func() {
if err := client.Tasks.ResumeTasks(ctx, tasksToResume); err != nil {
log.Printf("[WARN] failed to resume tasks: %s", err)
}
}()
if err != nil {
return err
}
for _, rootTask := range rootTasks {
// if a root task is started, then it needs to be suspended before the child tasks can be created
if rootTask.IsStarted() {
err := suspendTask(ctx, client, rootTask.ID())
if err != nil {
return err
}

// resume the task after modifications are complete as long as it is not a standalone task
if !(rootTask.Name == name) {
defer func(identifier sdk.SchemaObjectIdentifier) { _ = resumeTask(ctx, client, identifier) }(rootTask.ID())
}
}
}
precedingTasks = append(precedingTasks, precedingTaskId)
}
createRequest.WithAfter(precedingTasks)
Expand All @@ -392,7 +384,7 @@ func CreateTask(d *schema.ResourceData, meta interface{}) error {
}

func waitForTaskStart(ctx context.Context, client *sdk.Client, id sdk.SchemaObjectIdentifier) error {
err := resumeTask(ctx, client, id)
err := client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(id).WithResume(sdk.Bool(true)))
if err != nil {
return fmt.Errorf("error starting task %s err = %w", id.FullyQualifiedName(), err)
}
Expand All @@ -408,47 +400,22 @@ func waitForTaskStart(ctx context.Context, client *sdk.Client, id sdk.SchemaObje
})
}

// suspendTask sends an alter request with the suspend flag set for the task
// identified by id. On failure a warning naming the task is logged and the
// underlying error is handed back to the caller; on success nil is returned.
func suspendTask(ctx context.Context, client *sdk.Client, id sdk.SchemaObjectIdentifier) error {
	request := sdk.NewAlterTaskRequest(id).WithSuspend(sdk.Bool(true))
	if alterErr := client.Tasks.Alter(ctx, request); alterErr != nil {
		log.Printf("[WARN] failed to suspend task %s", id.FullyQualifiedName())
		return alterErr
	}
	return nil
}

// resumeTask sends an alter request with the resume flag set for the task
// identified by id. On failure a warning naming the task is logged and the
// underlying error is handed back to the caller; on success nil is returned.
func resumeTask(ctx context.Context, client *sdk.Client, id sdk.SchemaObjectIdentifier) error {
	request := sdk.NewAlterTaskRequest(id).WithResume(sdk.Bool(true))
	if alterErr := client.Tasks.Alter(ctx, request); alterErr != nil {
		log.Printf("[WARN] failed to resume task %s", id.FullyQualifiedName())
		return alterErr
	}
	return nil
}

// UpdateTask implements schema.UpdateFunc.
func UpdateTask(d *schema.ResourceData, meta interface{}) error {
client := meta.(*provider.Context).Client
ctx := context.Background()

taskId := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier)

rootTasks, err := sdk.GetRootTasks(client.Tasks, ctx, taskId)
tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, taskId, taskId)
defer func() {
if err := client.Tasks.ResumeTasks(ctx, tasksToResume); err != nil {
log.Printf("[WARN] failed to resume tasks: %s", err)
}
}()
if err != nil {
return err
}
for _, rootTask := range rootTasks {
// if a root task is started, then it needs to be suspended before the child tasks can be created
if rootTask.IsStarted() {
err := suspendTask(ctx, client, rootTask.ID())
if err != nil {
return err
}

// resume the task after modifications are complete as long as it is not a standalone task
if !(rootTask.Name == taskId.Name()) {
defer func(identifier sdk.SchemaObjectIdentifier) { _ = resumeTask(ctx, client, identifier) }(rootTask.ID())
}
}
}

if d.HasChange("warehouse") {
newWarehouse := d.Get("warehouse")
Expand Down Expand Up @@ -497,7 +464,9 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error {

if d.HasChange("after") {
// making changes to after require suspending the current task
if err := suspendTask(ctx, client, taskId); err != nil {
// (the task will be brought up to the correct running state in the "enabled" check at the bottom of Update function).
err := client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(taskId).WithSuspend(sdk.Bool(true)))
if err != nil {
return fmt.Errorf("error suspending task %s, err: %w", taskId.FullyQualifiedName(), err)
}

Expand Down Expand Up @@ -532,29 +501,19 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error {
toAdd = append(toAdd, sdk.NewSchemaObjectIdentifier(taskId.DatabaseName(), taskId.SchemaName(), dep))
}
}
// TODO [SNOW-1007541]: for now leaving old copy-pasted implementation; extract function for task suspension in following change
if len(toAdd) > 0 {
// need to suspend any new root tasks from dependencies before adding them
for _, dep := range toAdd {
rootTasks, err := sdk.GetRootTasks(client.Tasks, ctx, dep)
for _, depId := range toAdd {
tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, depId, taskId)
defer func() {
if err := client.Tasks.ResumeTasks(ctx, tasksToResume); err != nil {
log.Printf("[WARN] failed to resume tasks: %s", err)
}
}()
if err != nil {
return err
}
for _, rootTask := range rootTasks {
// if a root task is started, then it needs to be suspended before the child tasks can be created
if rootTask.IsStarted() {
err := suspendTask(ctx, client, rootTask.ID())
if err != nil {
return err
}

// resume the task after modifications are complete as long as it is not a standalone task
if !(rootTask.Name == taskId.Name()) {
defer func(identifier sdk.SchemaObjectIdentifier) { _ = resumeTask(ctx, client, identifier) }(rootTask.ID())
}
}
}
}

if err := client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(taskId).WithAddAfter(toAdd)); err != nil {
return fmt.Errorf("error adding after dependencies from task %s", taskId.FullyQualifiedName())
}
Expand Down Expand Up @@ -702,10 +661,11 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error {
log.Printf("[WARN] failed to resume task %s", taskId.FullyQualifiedName())
}
} else {
if suspendTask(ctx, client, taskId) != nil {
return fmt.Errorf("[WARN] failed to suspend task %s", taskId.FullyQualifiedName())
if err := client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(taskId).WithSuspend(sdk.Bool(true))); err != nil {
return fmt.Errorf("failed to suspend task %s", taskId.FullyQualifiedName())
}
}

return ReadTask(d, meta)
}

Expand All @@ -716,24 +676,15 @@ func DeleteTask(d *schema.ResourceData, meta interface{}) error {

taskId := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier)

rootTasks, err := sdk.GetRootTasks(client.Tasks, ctx, taskId)
tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, taskId, taskId)
defer func() {
if err := client.Tasks.ResumeTasks(ctx, tasksToResume); err != nil {
log.Printf("[WARN] failed to resume tasks: %s", err)
}
}()
if err != nil {
return err
}
for _, rootTask := range rootTasks {
// if a root task is started, then it needs to be suspended before the child tasks can be created
if rootTask.IsStarted() {
err := suspendTask(ctx, client, rootTask.ID())
if err != nil {
return err
}

// resume the task after modifications are complete as long as it is not a standalone task
if !(rootTask.Name == taskId.Name()) {
defer func(identifier sdk.SchemaObjectIdentifier) { _ = resumeTask(ctx, client, identifier) }(rootTask.ID())
}
}
}

dropRequest := sdk.NewDropTaskRequest(taskId)
err = client.Tasks.Drop(ctx, dropRequest)
Expand Down
2 changes: 2 additions & 0 deletions pkg/sdk/tasks_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ type Tasks interface {
ShowByID(ctx context.Context, id SchemaObjectIdentifier) (*Task, error)
Describe(ctx context.Context, id SchemaObjectIdentifier) (*Task, error)
Execute(ctx context.Context, request *ExecuteTaskRequest) error
SuspendRootTasks(ctx context.Context, taskId SchemaObjectIdentifier, id SchemaObjectIdentifier) ([]SchemaObjectIdentifier, error)
ResumeTasks(ctx context.Context, ids []SchemaObjectIdentifier) error
}

// CreateTaskOptions is based on https://docs.snowflake.com/en/sql-reference/sql/create-task.
Expand Down
44 changes: 44 additions & 0 deletions pkg/sdk/tasks_impl_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package sdk
import (
"context"
"encoding/json"
"errors"
"log"
"slices"
"strings"

Expand Down Expand Up @@ -65,6 +67,48 @@ func (v *tasks) Execute(ctx context.Context, request *ExecuteTaskRequest) error
return validateAndExec(v.client, ctx, opts)
}

// SuspendRootTasks looks up the root tasks of taskId via GetRootTasks and sends
// a suspend request to every root that reports itself as started.
//
// A failed lookup aborts immediately with a nil slice. A failed suspension does
// not abort the loop: it is logged as a warning, collected, and returned joined
// with any other suspension errors.
//
// The returned slice holds the IDs of the started roots whose name differs from
// id's name, whether or not their suspension succeeded. It is never nil when the
// lookup succeeds and is meant to be passed to ResumeTasks afterwards.
//
// TODO(SNOW-1277135): See if the separate id parameter is necessary or could be removed.
func (v *tasks) SuspendRootTasks(ctx context.Context, taskId SchemaObjectIdentifier, id SchemaObjectIdentifier) ([]SchemaObjectIdentifier, error) {
	roots, err := GetRootTasks(v.client.Tasks, ctx, taskId)
	if err != nil {
		return nil, err
	}

	var failures []error
	toResume := []SchemaObjectIdentifier{}

	for _, root := range roots {
		// Only a started root blocks modifications of its child tasks; anything else is left alone.
		if !root.IsStarted() {
			continue
		}

		rootId := root.ID()
		suspendRequest := NewAlterTaskRequest(rootId).WithSuspend(Bool(true))
		if alterErr := v.client.Tasks.Alter(ctx, suspendRequest); alterErr != nil {
			log.Printf("[WARN] failed to suspend task %s", rootId.FullyQualifiedName())
			failures = append(failures, alterErr)
		}

		// A root sharing its name with id is treated as a standalone task and is not scheduled for resuming.
		// TODO(SNOW-1277135): Document the purpose of this check and why GetRootTasks needs a different value (taskId).
		if root.Name == id.Name() {
			continue
		}
		toResume = append(toResume, rootId)
	}

	return toResume, errors.Join(failures...)
}

// ResumeTasks sends an alter request with the resume flag set for every id in
// ids, in order. A failure for one task does not stop the remaining ones: it is
// logged as a warning and collected. The result is the join of all collected
// errors, which is nil when every request succeeded or ids is empty.
func (v *tasks) ResumeTasks(ctx context.Context, ids []SchemaObjectIdentifier) error {
	var failures []error
	for i := range ids {
		taskId := ids[i]
		resumeRequest := NewAlterTaskRequest(taskId).WithResume(Bool(true))
		if alterErr := v.client.Tasks.Alter(ctx, resumeRequest); alterErr != nil {
			log.Printf("[WARN] failed to resume task %s", taskId.FullyQualifiedName())
			failures = append(failures, alterErr)
		}
	}
	return errors.Join(failures...)
}

// GetRootTasks is a way to get all root tasks for the given tasks.
// Snowflake does not have (yet) a method to do it without traversing the task graph manually.
// Task DAG should have a single root but this is checked when the root task is being resumed; that's why we return here multiple roots.
Expand Down
131 changes: 131 additions & 0 deletions pkg/sdk/testint/tasks_gen_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -592,4 +592,135 @@ func TestInt_Tasks(t *testing.T) {
err := client.Tasks.Execute(ctx, executeRequest)
require.NoError(t, err)
})

t.Run("temporarily suspend root tasks", func(t *testing.T) {
rootTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
rootTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(rootTaskId, sql).WithSchedule(sdk.String("60 minutes")))

id := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
task := createTaskWithRequest(t, sdk.NewCreateTaskRequest(id, sql).WithAfter([]sdk.SchemaObjectIdentifier{rootTask.ID()}))

require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithResume(sdk.Bool(true))))
t.Cleanup(func() {
require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithSuspend(sdk.Bool(true))))
})

tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, task.ID(), task.ID())
require.NoError(t, err)
require.NotEmpty(t, tasksToResume)

rootTaskStatus, err := client.Tasks.ShowByID(ctx, rootTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateSuspended, rootTaskStatus.State)

require.NoError(t, client.Tasks.ResumeTasks(ctx, tasksToResume))

rootTaskStatus, err = client.Tasks.ShowByID(ctx, rootTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateStarted, rootTaskStatus.State)
})

t.Run("resume root tasks within a graph containing more than one root task", func(t *testing.T) {
rootTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
rootTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(rootTaskId, sql).WithSchedule(sdk.String("60 minutes")))

secondRootTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
secondRootTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(secondRootTaskId, sql).WithSchedule(sdk.String("60 minutes")))

id := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
_ = createTaskWithRequest(t, sdk.NewCreateTaskRequest(id, sql).WithAfter([]sdk.SchemaObjectIdentifier{rootTask.ID(), secondRootTask.ID()}))

require.ErrorContains(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithResume(sdk.Bool(true))), "The graph has more than one root task (one without predecessors)")
require.ErrorContains(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(secondRootTask.ID()).WithResume(sdk.Bool(true))), "The graph has more than one root task (one without predecessors)")
})

t.Run("suspend root tasks temporarily with three sequentially connected tasks - last in DAG", func(t *testing.T) {
rootTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
rootTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(rootTaskId, sql).WithSchedule(sdk.String("60 minutes")))

middleTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
middleTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(middleTaskId, sql).WithAfter([]sdk.SchemaObjectIdentifier{rootTask.ID()}))

id := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
task := createTaskWithRequest(t, sdk.NewCreateTaskRequest(id, sql).WithAfter([]sdk.SchemaObjectIdentifier{middleTask.ID()}))

require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(middleTask.ID()).WithResume(sdk.Bool(true))))
t.Cleanup(func() {
require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(middleTask.ID()).WithSuspend(sdk.Bool(true))))
})

require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithResume(sdk.Bool(true))))
t.Cleanup(func() {
require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithSuspend(sdk.Bool(true))))
})

tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, task.ID(), task.ID())
require.NoError(t, err)
require.NotEmpty(t, tasksToResume)
require.Contains(t, tasksToResume, rootTask.ID())

rootTaskStatus, err := client.Tasks.ShowByID(ctx, rootTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateSuspended, rootTaskStatus.State)

middleTaskStatus, err := client.Tasks.ShowByID(ctx, middleTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateStarted, middleTaskStatus.State)

require.NoError(t, client.Tasks.ResumeTasks(ctx, tasksToResume))

rootTaskStatus, err = client.Tasks.ShowByID(ctx, rootTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateStarted, rootTaskStatus.State)

middleTaskStatus, err = client.Tasks.ShowByID(ctx, middleTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateStarted, middleTaskStatus.State)
})

t.Run("suspend root tasks temporarily with three sequentially connected tasks - middle in DAG", func(t *testing.T) {
rootTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
rootTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(rootTaskId, sql).WithSchedule(sdk.String("60 minutes")))

middleTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
middleTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(middleTaskId, sql).WithAfter([]sdk.SchemaObjectIdentifier{rootTask.ID()}))

childTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
childTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(childTaskId, sql).WithAfter([]sdk.SchemaObjectIdentifier{middleTask.ID()}))

require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(childTask.ID()).WithResume(sdk.Bool(true))))
t.Cleanup(func() {
require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(childTask.ID()).WithSuspend(sdk.Bool(true))))
})

require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithResume(sdk.Bool(true))))
t.Cleanup(func() {
require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithSuspend(sdk.Bool(true))))
})

tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, middleTask.ID(), middleTask.ID())
require.NoError(t, err)
require.NotEmpty(t, tasksToResume)

rootTaskStatus, err := client.Tasks.ShowByID(ctx, rootTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateSuspended, rootTaskStatus.State)

childTaskStatus, err := client.Tasks.ShowByID(ctx, childTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateStarted, childTaskStatus.State)

require.NoError(t, client.Tasks.ResumeTasks(ctx, tasksToResume))

rootTaskStatus, err = client.Tasks.ShowByID(ctx, rootTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateStarted, rootTaskStatus.State)

childTaskStatus, err = client.Tasks.ShowByID(ctx, childTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateStarted, childTaskStatus.State)
})

// TODO(SNOW-1277135): Create more tests with different sets of roots/children and see if the current implementation
// acts correctly in certain situations/edge cases.
}

0 comments on commit d684b5d

Please sign in to comment.