Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

disttask: support pause/resume for framework #47037

Merged
merged 18 commits into from
Sep 23, 2023
Merged
4 changes: 3 additions & 1 deletion disttask/framework/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@ go_test(
"framework_dynamic_dispatch_test.go",
"framework_err_handling_test.go",
"framework_ha_test.go",
"framework_pause_and_resume_test.go",
"framework_rollback_test.go",
"framework_test.go",
],
flaky = True,
race = "off",
shard_count = 29,
shard_count = 30,
deps = [
"//disttask/framework/dispatcher",
"//disttask/framework/handle",
"//disttask/framework/mock",
"//disttask/framework/mock/execute",
"//disttask/framework/proto",
Expand Down
61 changes: 61 additions & 0 deletions disttask/framework/dispatcher/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,29 @@ func (d *BaseDispatcher) scheduleTask() {
}
}
})

failpoint.Inject("pauseTaskAfterRefreshTask", func(val failpoint.Value) {
if val.(bool) && d.Task.State == proto.TaskStateRunning {
err := d.taskMgr.PauseTask(d.Task.ID)
if err != nil {
logutil.Logger(d.logCtx).Error("pause task failed", zap.Error(err))
}
}
})

switch d.Task.State {
case proto.TaskStateCancelling:
err = d.onCancelling()
case proto.TaskStatePausing:
err = d.onPausing()
case proto.TaskStatePaused:
err = d.onPaused()
// close the dispatcher.
if err == nil {
return
}
case proto.TaskStateResuming:
err = d.onResuming()
case proto.TaskStateReverting:
err = d.onReverting()
case proto.TaskStatePending:
Expand Down Expand Up @@ -193,6 +213,45 @@ func (d *BaseDispatcher) onCancelling() error {
return d.onErrHandlingStage(errs)
}

// handle task in pausing state, cancel all running subtasks.
func (d *BaseDispatcher) onPausing() error {
logutil.Logger(d.logCtx).Info("on pausing state", zap.String("state", d.Task.State), zap.Int64("stage", d.Task.Step))
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.Task.ID, proto.TaskStateRunning)
if err != nil {
logutil.Logger(d.logCtx).Warn("check task failed", zap.Error(err))
return err
}
if cnt == 0 {
logutil.Logger(d.logCtx).Info("all running tasks paused, update the task to paused state")
return d.updateTask(proto.TaskStatePaused, nil, RetrySQLTimes)
}
logutil.Logger(d.logCtx).Debug("on pausing state, this task keeps current state", zap.String("state", d.Task.State))
return nil
}

// handle task in paused state
func (d *BaseDispatcher) onPaused() error {
logutil.Logger(d.logCtx).Info("on paused state", zap.String("state", d.Task.State), zap.Int64("stage", d.Task.Step))
return nil
}

// handle task in resuming state
func (d *BaseDispatcher) onResuming() error {
logutil.Logger(d.logCtx).Info("on resuming state", zap.String("state", d.Task.State), zap.Int64("stage", d.Task.Step))
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.Task.ID, proto.TaskStatePaused)
if err != nil {
logutil.Logger(d.logCtx).Warn("check task failed", zap.Error(err))
return err
}
if cnt == 0 {
// Finish the resuming process.
logutil.Logger(d.logCtx).Info("all paused tasks finished, update the task to running state")
return d.updateTask(proto.TaskStateRunning, nil, RetrySQLTimes)
}

return d.taskMgr.ResumeAllSubtasks(d.Task.ID)
}

// handle task in reverting state, check all revert subtasks finished.
func (d *BaseDispatcher) onReverting() error {
logutil.Logger(d.logCtx).Debug("on reverting state", zap.String("state", d.Task.State), zap.Int64("stage", d.Task.Step))
Expand Down Expand Up @@ -328,6 +387,7 @@ func (d *BaseDispatcher) updateTask(taskState string, newSubTasks []*proto.Subta
logutil.Logger(d.logCtx).Error("cancel task failed", zap.Error(err))
}
})

var retryable bool
for i := 0; i < retryTimes; i++ {
retryable, err = d.taskMgr.UpdateGlobalTaskAndAddSubTasks(d.Task, newSubTasks, prevState)
Expand Down Expand Up @@ -602,6 +662,7 @@ func (d *BaseDispatcher) WithNewTxn(ctx context.Context, fn func(se sessionctx.C
}

// VerifyTaskStateTransform verifies whether the task state transform is valid.
// TODO: YWQ verify it's true.
func VerifyTaskStateTransform(from, to string) bool {
rules := map[string][]string{
proto.TaskStatePending: {
Expand Down
10 changes: 8 additions & 2 deletions disttask/framework/dispatcher/dispatcher_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,15 @@ func (dm *Manager) dispatchTaskLoop() {
}

// TODO: Consider getting these tasks, in addition to the task being worked on..
tasks, err := dm.taskMgr.GetGlobalTasksInStates(proto.TaskStatePending, proto.TaskStateRunning, proto.TaskStateReverting, proto.TaskStateCancelling)
tasks, err := dm.taskMgr.GetGlobalTasksInStates(
proto.TaskStatePending,
proto.TaskStateRunning,
proto.TaskStateReverting,
proto.TaskStateCancelling,
proto.TaskStateResuming,
)
if err != nil {
logutil.BgLogger().Warn("get unfinished(pending, running, reverting or cancelling) tasks failed", zap.Error(err))
logutil.BgLogger().Warn("get unfinished(pending, running, reverting, cancelling, resuming) tasks failed", zap.Error(err))
break
}

Expand Down
58 changes: 58 additions & 0 deletions disttask/framework/framework_pause_and_resume_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package framework_test

import (
"sync"
"testing"
"time"

"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/disttask/framework/handle"
"github.com/pingcap/tidb/disttask/framework/proto"
"github.com/pingcap/tidb/disttask/framework/storage"
"github.com/pingcap/tidb/testkit"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)

func CheckSubtasksState(t *testing.T, taskID int64, state string, expectedCnt int64) {
mgr, err := storage.GetTaskManager()
require.NoError(t, err)
mgr.PrintSubtaskInfo(taskID)
cnt, err := mgr.GetSubtaskInStatesCnt(taskID, state)
require.NoError(t, err)
require.Equal(t, expectedCnt, cnt)
}

func TestFrameworkPauseAndResumeBasic(t *testing.T) {
var m sync.Map
ctrl := gomock.NewController(t)
defer ctrl.Finish()
RegisterTaskMeta(t, ctrl, &m, &testDispatcherExt{})
distContext := testkit.NewDistExecutionContext(t, 2)

// dispatch and pause one task.
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/dispatcher/pauseTaskAfterRefreshTask", "2*return(true)"))
DispatchTaskAndCheckState("key1", t, &m, proto.TaskStatePaused)
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/dispatcher/pauseTaskAfterRefreshTask"))
CheckSubtasksState(t, 1, proto.TaskStatePaused, 3)

// resume one task.
handle.ResumeTask("key1")
time.Sleep(3 * time.Second)
CheckSubtasksState(t, 1, proto.TaskStateSucceed, 4)
distContext.Close()
}
32 changes: 28 additions & 4 deletions disttask/framework/handle/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,18 +106,42 @@ func SubmitAndRunGlobalTask(ctx context.Context, taskKey, taskType string, concu

// CancelGlobalTask cancels a global task.
func CancelGlobalTask(taskKey string) error {
globalTaskManager, err := storage.GetTaskManager()
taskManager, err := storage.GetTaskManager()
if err != nil {
return err
}
globalTask, err := globalTaskManager.GetGlobalTaskByKey(taskKey)
task, err := taskManager.GetGlobalTaskByKey(taskKey)
if err != nil {
return err
}
if globalTask == nil {
if task == nil {
logutil.BgLogger().Info("task not exist", zap.String("taskKey", taskKey))

return nil
}
return taskManager.CancelGlobalTask(task.ID)
}

// ResumeTask resumes a task.
func ResumeTask(taskKey string) error {
ywqzzy marked this conversation as resolved.
Show resolved Hide resolved
taskManager, err := storage.GetTaskManager()
if err != nil {
return err
}
task, err := taskManager.GetGlobalTaskByKey(taskKey)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We check this result in line140(through found), so could we remove this check?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We check this result in line140(through found), so could we remove this check?

I will change the resumeTask(int64) to resumeTask(string)

if err != nil {
return err
}
if task == nil {
logutil.BgLogger().Info("task not exist", zap.String("taskKey", taskKey))
return nil
}
found, err := taskManager.ResumeTask(task.ID)
if !found {
logutil.BgLogger().Info("task not resumable", zap.String("taskKey", taskKey))
return nil
}
return globalTaskManager.CancelGlobalTask(globalTask.ID)
return err
}

// RunWithRetry runs a function with retry, when retry exceed max retry time, it
Expand Down
10 changes: 6 additions & 4 deletions disttask/framework/scheduler/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,16 @@ type TaskTable interface {
GetGlobalTasksInStates(states ...interface{}) (task []*proto.Task, err error)
GetGlobalTaskByID(taskID int64) (task *proto.Task, err error)

GetSubtaskInStates(instanceID string, taskID int64, step int64, states ...interface{}) (*proto.Subtask, error)
GetSubtaskInStates(tidbID string, taskID int64, step int64, states ...interface{}) (*proto.Subtask, error)
StartManager(tidbID string, role string) error
StartSubtask(subtaskID int64) error
UpdateSubtaskStateAndError(subtaskID int64, state string, err error) error
FinishSubtask(subtaskID int64, meta []byte) error

HasSubtasksInStates(instanceID string, taskID int64, step int64, states ...interface{}) (bool, error)
UpdateErrorToSubtask(instanceID string, taskID int64, err error) error
IsSchedulerCanceled(taskID int64, instanceID string) (bool, error)
HasSubtasksInStates(tidbID string, taskID int64, step int64, states ...interface{}) (bool, error)
UpdateErrorToSubtask(tidbID string, taskID int64, err error) error
IsSchedulerCanceled(taskID int64, tidbID string) (bool, error)
PauseSubtasks(tidbID string, taskID int64) error
}

// Pool defines the interface of a pool.
Expand All @@ -50,6 +51,7 @@ type Scheduler interface {
Init(context.Context) error
Run(context.Context, *proto.Task) error
Rollback(context.Context, *proto.Task) error
Pause(context.Context, *proto.Task) error
Close()
}

Expand Down
46 changes: 42 additions & 4 deletions disttask/framework/scheduler/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func (m *Manager) fetchAndHandleRunnableTasks(ctx context.Context) {
}
}

// fetchAndFastCancelTasks fetches the reverting tasks from the global task table and fast cancels them.
// fetchAndFastCancelTasks fetches the reverting/pausing tasks from the global task table and fast cancels them.
func (m *Manager) fetchAndFastCancelTasks(ctx context.Context) {
ticker := time.NewTicker(checkTime)
for {
Expand All @@ -173,6 +173,17 @@ func (m *Manager) fetchAndFastCancelTasks(ctx context.Context) {
continue
}
m.onCanceledTasks(ctx, tasks)

// cancel pausing subtasks, and mark them as paused.
pausingTasks, err := m.taskTable.GetGlobalTasksInStates(proto.TaskStatePausing)
if err != nil {
m.onError(err)
continue
}
if err := m.onPausingTasks(pausingTasks); err != nil {
m.onError(err)
continue
}
}
}
}
Expand All @@ -181,7 +192,13 @@ func (m *Manager) fetchAndFastCancelTasks(ctx context.Context) {
func (m *Manager) onRunnableTasks(ctx context.Context, tasks []*proto.Task) {
tasks = m.filterAlreadyHandlingTasks(tasks)
for _, task := range tasks {
exist, err := m.taskTable.HasSubtasksInStates(m.id, task.ID, task.Step, proto.TaskStatePending, proto.TaskStateRevertPending)
// Need to poll pending/revertPending/Running subtasks.
// It's necessary to poll running tasks since manager has possibility of restart.
// Then the subtasks are in running state and need to be handled.
exist, err := m.taskTable.HasSubtasksInStates(
m.id, task.ID, task.Step,
proto.TaskStatePending, proto.TaskStateRevertPending, proto.TaskStateRunning)

ywqzzy marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
logutil.Logger(m.logCtx).Error("check subtask exist failed", zap.Error(err))
m.onError(err)
Expand Down Expand Up @@ -218,6 +235,22 @@ func (m *Manager) onCanceledTasks(_ context.Context, tasks []*proto.Task) {
}
}

// onPausingTasks pauses/cancels the pending/running tasks.
func (m *Manager) onPausingTasks(tasks []*proto.Task) error {
m.mu.RLock()
ywqzzy marked this conversation as resolved.
Show resolved Hide resolved
defer m.mu.RUnlock()
for _, task := range tasks {
logutil.Logger(m.logCtx).Info("onPausingTasks", zap.Any("task_id", task.ID))
if cancel, ok := m.mu.handlingTasks[task.ID]; ok && cancel != nil {
cancel()
}
if err := m.taskTable.PauseSubtasks(m.id, task.ID); err != nil {
return err
}
}
return nil
}

// cancelAllRunningTasks cancels all running tasks.
func (m *Manager) cancelAllRunningTasks() {
m.mu.RLock()
Expand Down Expand Up @@ -296,19 +329,24 @@ func (m *Manager) onRunnableTask(ctx context.Context, task *proto.Task) {
zap.Int64("task_id", task.ID), zap.Int64("step", task.Step), zap.String("state", task.State))
return
}
if exist, err := m.taskTable.HasSubtasksInStates(m.id, task.ID, task.Step, proto.TaskStatePending, proto.TaskStateRevertPending); err != nil {

// Considering manager restart scene, scheduler needs to handle running subtasks.
if exist, err := m.taskTable.HasSubtasksInStates(
m.id, task.ID, task.Step,
proto.TaskStatePending, proto.TaskStateRevertPending, proto.TaskStateRunning); err != nil {
ywqzzy marked this conversation as resolved.
Show resolved Hide resolved
m.onError(err)
return
} else if !exist {
continue
}

switch task.State {
case proto.TaskStateRunning:
runCtx, runCancel := context.WithCancel(ctx)
m.registerCancelFunc(task.ID, runCancel)
err = scheduler.Run(runCtx, task)
runCancel()
case proto.TaskStatePausing:
err = scheduler.Pause(ctx, task)
case proto.TaskStateReverting:
err = scheduler.Rollback(ctx, task)
}
Expand Down
17 changes: 15 additions & 2 deletions disttask/framework/scheduler/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,10 @@ func (s *BaseScheduler) run(ctx context.Context, task *proto.Task) error {
if err := s.getError(); err != nil {
break
}

subtask, err := s.taskTable.GetSubtaskInStates(s.id, task.ID, task.Step, proto.TaskStatePending)
// Considering manager restart scene, scheduler needs to handle running subtasks.
subtask, err := s.taskTable.GetSubtaskInStates(
s.id, task.ID, task.Step,
proto.TaskStatePending, proto.TaskStateRunning)
ywqzzy marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
logutil.Logger(s.logCtx).Warn("GetSubtaskInStates meets error", zap.Error(err))
continue
Expand Down Expand Up @@ -366,6 +368,17 @@ func (s *BaseScheduler) Rollback(ctx context.Context, task *proto.Task) error {
return s.getError()
}

// Pause pause the scheduler task.
func (s *BaseScheduler) Pause(_ context.Context, task *proto.Task) error {
logutil.Logger(s.logCtx).Info("scheduler pause subtasks")
// pause all running subtasks.
if err := s.taskTable.PauseSubtasks(s.id, task.ID); err != nil {
s.onError(err)
return s.getError()
}
return nil
}

// Close closes the scheduler when all the subtasks are complete.
func (*BaseScheduler) Close() {
}
Expand Down
Loading