Skip to content

Commit

Permalink
disttask: support pause/resume for framework (#47037)
Browse files Browse the repository at this point in the history
ref #46258
  • Loading branch information
ywqzzy authored Sep 23, 2023
1 parent 9d29580 commit 34438f8
Show file tree
Hide file tree
Showing 20 changed files with 525 additions and 114 deletions.
4 changes: 3 additions & 1 deletion disttask/framework/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@ go_test(
"framework_dynamic_dispatch_test.go",
"framework_err_handling_test.go",
"framework_ha_test.go",
"framework_pause_and_resume_test.go",
"framework_rollback_test.go",
"framework_test.go",
],
flaky = True,
race = "off",
shard_count = 29,
shard_count = 30,
deps = [
"//disttask/framework/dispatcher",
"//disttask/framework/handle",
"//disttask/framework/mock",
"//disttask/framework/mock/execute",
"//disttask/framework/proto",
Expand Down
2 changes: 1 addition & 1 deletion disttask/framework/dispatcher/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ go_test(
embed = [":dispatcher"],
flaky = True,
race = "off",
shard_count = 11,
shard_count = 13,
deps = [
"//disttask/framework/mock",
"//disttask/framework/proto",
Expand Down
78 changes: 78 additions & 0 deletions disttask/framework/dispatcher/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,40 @@ func (d *BaseDispatcher) scheduleTask() {
}
}
})

failpoint.Inject("pausePendingTask", func(val failpoint.Value) {
if val.(bool) && d.Task.State == proto.TaskStatePending {
_, err := d.taskMgr.PauseTask(d.Task.Key)
if err != nil {
logutil.Logger(d.logCtx).Error("pause task failed", zap.Error(err))
}
d.Task.State = proto.TaskStatePausing
}
})

failpoint.Inject("pauseTaskAfterRefreshTask", func(val failpoint.Value) {
if val.(bool) && d.Task.State == proto.TaskStateRunning {
_, err := d.taskMgr.PauseTask(d.Task.Key)
if err != nil {
logutil.Logger(d.logCtx).Error("pause task failed", zap.Error(err))
}
d.Task.State = proto.TaskStatePausing
}
})

switch d.Task.State {
case proto.TaskStateCancelling:
err = d.onCancelling()
case proto.TaskStatePausing:
err = d.onPausing()
case proto.TaskStatePaused:
err = d.onPaused()
// close the dispatcher.
if err == nil {
return
}
case proto.TaskStateResuming:
err = d.onResuming()
case proto.TaskStateReverting:
err = d.onReverting()
case proto.TaskStatePending:
Expand Down Expand Up @@ -204,6 +235,52 @@ func (d *BaseDispatcher) onCancelling() error {
return d.onErrHandlingStage(errs)
}

// onPausing handles a task in the pausing state. It counts the subtasks that
// are still in the running or pending state: when none are left the task is
// updated to the paused state, otherwise the task keeps its current state.
// An error from the count query is logged and returned unchanged.
func (d *BaseDispatcher) onPausing() error {
	logutil.Logger(d.logCtx).Info("on pausing state", zap.String("state", d.Task.State), zap.Int64("stage", d.Task.Step))
	// TODO(ywq): the original author flagged this query for removal.
	remaining, err := d.taskMgr.GetSubtaskInStatesCnt(d.Task.ID, proto.TaskStateRunning, proto.TaskStatePending)
	if err != nil {
		logutil.Logger(d.logCtx).Warn("check task failed", zap.Error(err))
		return err
	}
	if remaining != 0 {
		// Some subtasks are still running or pending, so nothing changes yet.
		logutil.Logger(d.logCtx).Debug("on pausing state, this task keeps current state", zap.String("state", d.Task.State))
		return nil
	}
	logutil.Logger(d.logCtx).Info("all running subtasks paused, update the task to paused state")
	return d.updateTask(proto.TaskStatePaused, nil, RetrySQLTimes)
}

// onPaused handles a task in the paused state. It only records the current
// state and step in the log and always returns nil.
func (d *BaseDispatcher) onPaused() error {
	logger := logutil.Logger(d.logCtx)
	logger.Info(
		"on paused state",
		zap.String("state", d.Task.State),
		zap.Int64("stage", d.Task.Step),
	)
	return nil
}

// TestSyncChan is used to sync the test.
// The channel is unbuffered: onResuming sends on it from inside the
// "syncAfterResume" failpoint, so that send blocks until a test receives
// from the channel.
var TestSyncChan = make(chan struct{})

// onResuming handles a task in the resuming state. While subtasks in the
// paused state remain, it asks the task manager to resume them. Once none are
// left, it updates the task to the running state and returns the result of
// that update. An error from the count query is logged and returned unchanged.
func (d *BaseDispatcher) onResuming() error {
	logutil.Logger(d.logCtx).Info("on resuming state", zap.String("state", d.Task.State), zap.Int64("stage", d.Task.Step))
	pausedCnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.Task.ID, proto.TaskStatePaused)
	if err != nil {
		logutil.Logger(d.logCtx).Warn("check task failed", zap.Error(err))
		return err
	}
	if pausedCnt != 0 {
		return d.taskMgr.ResumeSubtasks(d.Task.ID)
	}

	// Finish the resuming process.
	logutil.Logger(d.logCtx).Info("all paused tasks converted to pending state, update the task to running state")
	err = d.updateTask(proto.TaskStateRunning, nil, RetrySQLTimes)
	// When the failpoint is enabled, signal the waiting test after the update
	// attempt, whether or not it succeeded.
	failpoint.Inject("syncAfterResume", func() {
		TestSyncChan <- struct{}{}
	})
	return err
}

// handle task in reverting state, check all revert subtasks finished.
func (d *BaseDispatcher) onReverting() error {
logutil.Logger(d.logCtx).Debug("on reverting state", zap.String("state", d.Task.State), zap.Int64("stage", d.Task.Step))
Expand Down Expand Up @@ -339,6 +416,7 @@ func (d *BaseDispatcher) updateTask(taskState string, newSubTasks []*proto.Subta
logutil.Logger(d.logCtx).Error("cancel task failed", zap.Error(err))
}
})

var retryable bool
for i := 0; i < retryTimes; i++ {
retryable, err = d.taskMgr.UpdateGlobalTaskAndAddSubTasks(d.Task, newSubTasks, prevState)
Expand Down
10 changes: 8 additions & 2 deletions disttask/framework/dispatcher/dispatcher_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,15 @@ func (dm *Manager) dispatchTaskLoop() {
}

// TODO: Consider getting these tasks, in addition to the task being worked on.
tasks, err := dm.taskMgr.GetGlobalTasksInStates(proto.TaskStatePending, proto.TaskStateRunning, proto.TaskStateReverting, proto.TaskStateCancelling)
tasks, err := dm.taskMgr.GetGlobalTasksInStates(
proto.TaskStatePending,
proto.TaskStateRunning,
proto.TaskStateReverting,
proto.TaskStateCancelling,
proto.TaskStateResuming,
)
if err != nil {
logutil.BgLogger().Warn("get unfinished(pending, running, reverting or cancelling) tasks failed", zap.Error(err))
logutil.BgLogger().Warn("get unfinished(pending, running, reverting, cancelling, resuming) tasks failed", zap.Error(err))
break
}

Expand Down
54 changes: 44 additions & 10 deletions disttask/framework/dispatcher/dispatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ func TestTaskFailInManager(t *testing.T) {
}, time.Second*10, time.Millisecond*300)
}

func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel bool) {
func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel, isPauseAndResume bool) {
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/domain/MockDisableDistTask", "return(true)"))
defer func() {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/domain/MockDisableDistTask"))
Expand Down Expand Up @@ -336,14 +336,16 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel

// test DetectTaskLoop
checkGetTaskState := func(expectedState string) {
for i := 0; i < cnt; i++ {
i := 0
for ; i < cnt; i++ {
tasks, err := mgr.GetGlobalTasksInStates(expectedState)
require.NoError(t, err)
if len(tasks) == taskCnt {
break
}
time.Sleep(time.Millisecond * 50)
}
require.Less(t, i, cnt)
}
// Test all subtasks are successful.
var err error
Expand All @@ -365,6 +367,30 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel
err = mgr.CancelGlobalTask(int64(i))
require.NoError(t, err)
}
} else if isPauseAndResume {
for i := 0; i < taskCnt; i++ {
found, err := mgr.PauseTask(fmt.Sprintf("%d", i))
require.Equal(t, true, found)
require.NoError(t, err)
}
for i := 1; i <= subtaskCnt*taskCnt; i++ {
err = mgr.UpdateSubtaskStateAndError(int64(i), proto.TaskStatePaused, nil)
require.NoError(t, err)
}
checkGetTaskState(proto.TaskStatePaused)
for i := 0; i < taskCnt; i++ {
found, err := mgr.ResumeTask(fmt.Sprintf("%d", i))
require.Equal(t, true, found)
require.NoError(t, err)
}

// Mock subtasks succeed.
for i := 1; i <= subtaskCnt*taskCnt; i++ {
err = mgr.UpdateSubtaskStateAndError(int64(i), proto.TaskStateSucceed, nil)
require.NoError(t, err)
}
checkGetTaskState(proto.TaskStateSucceed)
return
} else {
// Test each task has a subtask failed.
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/storage/MockUpdateTaskErr", "1*return(true)"))
Expand Down Expand Up @@ -401,35 +427,43 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel
}

// TestSimple runs checkDispatch with a single task and isSucc set to true.
func TestSimple(t *testing.T) {
checkDispatch(t, 1, true, false, false)
checkDispatch(t, 1, true, false, false, false)
}

// TestSimpleErrStage runs checkDispatch with a single task and every mode
// flag set to false.
func TestSimpleErrStage(t *testing.T) {
checkDispatch(t, 1, false, false, false)
checkDispatch(t, 1, false, false, false, false)
}

// TestSimpleCancel runs checkDispatch with a single task and isCancel set to
// true.
func TestSimpleCancel(t *testing.T) {
checkDispatch(t, 1, false, true, false)
checkDispatch(t, 1, false, true, false, false)
}

// TestSimpleSubtaskCancel runs checkDispatch with a single task and
// isSubtaskCancel set to true.
func TestSimpleSubtaskCancel(t *testing.T) {
checkDispatch(t, 1, false, false, true)
checkDispatch(t, 1, false, false, true, false)
}

// TestParallel runs checkDispatch with three tasks and isSucc set to true.
func TestParallel(t *testing.T) {
checkDispatch(t, 3, true, false, false)
checkDispatch(t, 3, true, false, false, false)
}

// TestParallelErrStage runs checkDispatch with three tasks and every mode
// flag set to false.
func TestParallelErrStage(t *testing.T) {
checkDispatch(t, 3, false, false, false)
checkDispatch(t, 3, false, false, false, false)
}

// TestParallelCancel runs checkDispatch with three tasks and isCancel set to
// true.
func TestParallelCancel(t *testing.T) {
checkDispatch(t, 3, false, true, false)
checkDispatch(t, 3, false, true, false, false)
}

// TestParallelSubtaskCancel runs checkDispatch with three tasks and
// isSubtaskCancel set to true.
func TestParallelSubtaskCancel(t *testing.T) {
checkDispatch(t, 3, false, false, true)
checkDispatch(t, 3, false, false, true, false)
}

// TestPause runs the pause-and-resume path of checkDispatch with one task.
func TestPause(t *testing.T) {
	const (
		taskCnt          = 1
		isPauseAndResume = true
	)
	checkDispatch(t, taskCnt, false, false, false, isPauseAndResume)
}

// TestParallelPause runs the pause-and-resume path of checkDispatch with
// three tasks.
func TestParallelPause(t *testing.T) {
	const (
		taskCnt          = 3
		isPauseAndResume = true
	)
	checkDispatch(t, taskCnt, false, false, false, isPauseAndResume)
}

func TestVerifyTaskStateTransform(t *testing.T) {
Expand Down
83 changes: 83 additions & 0 deletions disttask/framework/framework_pause_and_resume_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package framework_test

import (
"sync"
"testing"

"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/disttask/framework/dispatcher"
"github.com/pingcap/tidb/disttask/framework/handle"
"github.com/pingcap/tidb/disttask/framework/proto"
"github.com/pingcap/tidb/disttask/framework/storage"
"github.com/pingcap/tidb/testkit"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)

// CheckSubtasksState asserts that the number of subtasks of taskID currently
// in the given state, added to the count returned by
// storage.GetSubtasksFromHistoryByTaskIDForTest for the same task, equals
// expectedCnt. It also prints the task's subtask info before counting.
func CheckSubtasksState(t *testing.T, taskID int64, state string, expectedCnt int64) {
	taskMgr, err := storage.GetTaskManager()
	require.NoError(t, err)
	taskMgr.PrintSubtaskInfo(taskID)

	liveCnt, err := taskMgr.GetSubtaskInStatesCnt(taskID, state)
	require.NoError(t, err)

	archivedCnt, err := storage.GetSubtasksFromHistoryByTaskIDForTest(taskMgr, taskID)
	require.NoError(t, err)

	total := liveCnt + int64(archivedCnt)
	require.Equal(t, expectedCnt, total)
}

// TestFrameworkPauseAndResume exercises pause and resume end to end in two
// phases: first a task paused while running, then a task paused while still
// pending. In each phase the task is driven to the paused state through a
// dispatcher failpoint, resumed through handle.ResumeTask, and then expected
// to finish with four succeeded subtasks and no subtask errors.
func TestFrameworkPauseAndResume(t *testing.T) {
	var m sync.Map
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	RegisterTaskMeta(t, ctrl, &m, &testDispatcherExt{})
	distContext := testkit.NewDistExecutionContext(t, 3)
	// Deferred so the execution context is closed even when a require below
	// aborts the test; it still runs before ctrl.Finish, as before.
	defer distContext.Close()

	// 1. dispatch and pause one running task.
	require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/dispatcher/pauseTaskAfterRefreshTask", "2*return(true)"))
	require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/dispatcher/syncAfterResume", "return()"))
	DispatchTaskAndCheckState("key1", t, &m, proto.TaskStatePaused)
	require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/dispatcher/pauseTaskAfterRefreshTask"))
	// 4 subtasks dispatched.
	require.NoError(t, handle.ResumeTask("key1"))
	// Wait for the dispatcher to signal that the resume step has run.
	<-dispatcher.TestSyncChan
	WaitTaskExit(t, "key1")
	CheckSubtasksState(t, 1, proto.TaskStateSucceed, 4)
	require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/dispatcher/syncAfterResume"))

	mgr, err := storage.GetTaskManager()
	require.NoError(t, err)
	errs, err := mgr.CollectSubTaskError(1)
	require.NoError(t, err)
	require.Empty(t, errs)

	// 2. pause pending task.
	require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/dispatcher/pausePendingTask", "2*return(true)"))
	require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/dispatcher/syncAfterResume", "1*return()"))
	DispatchTaskAndCheckState("key2", t, &m, proto.TaskStatePaused)

	require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/dispatcher/pausePendingTask"))
	// 4 subtasks dispatched.
	require.NoError(t, handle.ResumeTask("key2"))
	<-dispatcher.TestSyncChan
	WaitTaskExit(t, "key2")
	// NOTE(review): this phase runs "key2" but the checks below still use task
	// ID 1 — presumably the second task's ID was intended; confirm.
	CheckSubtasksState(t, 1, proto.TaskStateSucceed, 4)
	require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/dispatcher/syncAfterResume"))

	errs, err = mgr.CollectSubTaskError(1)
	require.NoError(t, err)
	require.Empty(t, errs)
}
13 changes: 9 additions & 4 deletions disttask/framework/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,22 +187,27 @@ func RegisterTaskMetaForExample3(t *testing.T, ctrl *gomock.Controller, m *sync.
// DispatchTask adds a new global task of type proto.TaskTypeExample under
// taskKey through the task manager, then returns whatever WaitTaskExit
// reports for that key. Any error from the task manager fails the test.
func DispatchTask(taskKey string, t *testing.T) *proto.Task {
mgr, err := storage.GetTaskManager()
require.NoError(t, err)
taskID, err := mgr.AddNewGlobalTask(taskKey, proto.TaskTypeExample, 8, nil)
_, err = mgr.AddNewGlobalTask(taskKey, proto.TaskTypeExample, 8, nil)
require.NoError(t, err)
start := time.Now()
return WaitTaskExit(t, taskKey)
}

func WaitTaskExit(t *testing.T, taskKey string) *proto.Task {
mgr, err := storage.GetTaskManager()
require.NoError(t, err)
var task *proto.Task
start := time.Now()
for {
if time.Since(start) > 10*time.Minute {
require.FailNow(t, "timeout")
}

time.Sleep(time.Second)
task, err = mgr.GetGlobalTaskByID(taskID)
task, err = mgr.GetGlobalTaskByKey(taskKey)

require.NoError(t, err)
require.NotNil(t, task)
if task.State != proto.TaskStatePending && task.State != proto.TaskStateRunning && task.State != proto.TaskStateCancelling && task.State != proto.TaskStateReverting {
if task.State != proto.TaskStatePending && task.State != proto.TaskStateRunning && task.State != proto.TaskStateCancelling && task.State != proto.TaskStateReverting && task.State != proto.TaskStatePausing {
break
}
}
Expand Down
Loading

0 comments on commit 34438f8

Please sign in to comment.