Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

disttask: support pause/resume for framework #47037

Merged
merged 18 commits into from
Sep 23, 2023
Merged
4 changes: 3 additions & 1 deletion disttask/framework/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@ go_test(
"framework_dynamic_dispatch_test.go",
"framework_err_handling_test.go",
"framework_ha_test.go",
"framework_pause_and_resume_test.go",
"framework_rollback_test.go",
"framework_test.go",
],
flaky = True,
race = "off",
shard_count = 29,
shard_count = 30,
deps = [
"//disttask/framework/dispatcher",
"//disttask/framework/handle",
"//disttask/framework/mock",
"//disttask/framework/mock/execute",
"//disttask/framework/proto",
Expand Down
2 changes: 1 addition & 1 deletion disttask/framework/dispatcher/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ go_test(
embed = [":dispatcher"],
flaky = True,
race = "off",
shard_count = 11,
shard_count = 13,
deps = [
"//disttask/framework/mock",
"//disttask/framework/proto",
Expand Down
79 changes: 79 additions & 0 deletions disttask/framework/dispatcher/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,40 @@ func (d *BaseDispatcher) scheduleTask() {
}
}
})

failpoint.Inject("pausePendingTask", func(val failpoint.Value) {
if val.(bool) && d.Task.State == proto.TaskStatePending {
err := d.taskMgr.PauseTask(d.Task.ID)
if err != nil {
logutil.Logger(d.logCtx).Error("pause task failed", zap.Error(err))
}
d.Task.State = proto.TaskStatePausing
}
})

failpoint.Inject("pauseTaskAfterRefreshTask", func(val failpoint.Value) {
if val.(bool) && d.Task.State == proto.TaskStateRunning {
err := d.taskMgr.PauseTask(d.Task.ID)
if err != nil {
logutil.Logger(d.logCtx).Error("pause task failed", zap.Error(err))
}
d.Task.State = proto.TaskStatePausing
}
})

switch d.Task.State {
case proto.TaskStateCancelling:
err = d.onCancelling()
case proto.TaskStatePausing:
err = d.onPausing()
case proto.TaskStatePaused:
err = d.onPaused()
// close the dispatcher.
if err == nil {
return
}
case proto.TaskStateResuming:
err = d.onResuming()
case proto.TaskStateReverting:
err = d.onReverting()
case proto.TaskStatePending:
Expand Down Expand Up @@ -204,6 +235,53 @@ func (d *BaseDispatcher) onCancelling() error {
return d.onErrHandlingStage(errs)
}

// onPausing handles a task that is in the pausing state.
// It counts the subtasks of the task that are still running or pending.
// When none are left, the task is updated to the paused state; otherwise
// the task keeps its current state and nil is returned.
func (d *BaseDispatcher) onPausing() error {
	logger := logutil.Logger(d.logCtx)
	logger.Info("on pausing state", zap.String("state", d.Task.State), zap.Int64("stage", d.Task.Step))
	remaining, err := d.taskMgr.GetSubtaskInStatesCnt(d.Task.ID, proto.TaskStateRunning, proto.TaskStatePending)
	// NOTE(review): the author tagged this spot "ywq todo remove"; presumably
	// the subtask dump below is temporary debugging output. Confirm before
	// deleting. It runs even when the count query above failed.
	d.taskMgr.PrintSubtaskInfo(d.Task.ID)
	if err != nil {
		logger.Warn("check task failed", zap.Error(err))
		return err
	}
	if remaining != 0 {
		// Some subtasks are still running or pending; check again later.
		logger.Debug("on pausing state, this task keeps current state", zap.String("state", d.Task.State))
		return nil
	}
	logger.Info("all running subtasks paused, update the task to paused state")
	return d.updateTask(proto.TaskStatePaused, nil, RetrySQLTimes)
}

// onPaused handles a task that is in the paused state.
// It only logs the task's state and step and always returns nil.
func (d *BaseDispatcher) onPaused() error {
	fields := []zap.Field{
		zap.String("state", d.Task.State),
		zap.Int64("stage", d.Task.Step),
	}
	logutil.Logger(d.logCtx).Info("on paused state", fields...)
	return nil
}

// TestSyncChan is used to sync the test.
// The channel is unbuffered. onResuming sends one value on it when the
// "syncAfterResume" failpoint is active, so that send blocks until a
// receiver (the test) reads from the channel.
var TestSyncChan = make(chan struct{})

// onResuming handles a task that is in the resuming state.
// It counts the subtasks that are still paused. While any remain, it asks
// the task manager to resume the task's subtasks and returns that result.
// Once none remain, the task is updated to the running state and the error
// from that update is returned.
func (d *BaseDispatcher) onResuming() error {
	logger := logutil.Logger(d.logCtx)
	logger.Info("on resuming state", zap.String("state", d.Task.State), zap.Int64("stage", d.Task.Step))
	pausedCnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.Task.ID, proto.TaskStatePaused)
	if err != nil {
		logger.Warn("check task failed", zap.Error(err))
		return err
	}
	if pausedCnt != 0 {
		// Paused subtasks still exist; request that they be resumed.
		return d.taskMgr.ResumeSubtasks(d.Task.ID)
	}

	// Finish the resuming process.
	logger.Info("all paused tasks converted to pending state, update the task to running state")
	err = d.updateTask(proto.TaskStateRunning, nil, RetrySQLTimes)
	// Test hook: signal the waiting test after the update attempt, whether
	// or not the update succeeded.
	failpoint.Inject("syncAfterResume", func() {
		TestSyncChan <- struct{}{}
	})
	return err
}

// handle task in reverting state, check all revert subtasks finished.
func (d *BaseDispatcher) onReverting() error {
logutil.Logger(d.logCtx).Debug("on reverting state", zap.String("state", d.Task.State), zap.Int64("stage", d.Task.Step))
Expand Down Expand Up @@ -339,6 +417,7 @@ func (d *BaseDispatcher) updateTask(taskState string, newSubTasks []*proto.Subta
logutil.Logger(d.logCtx).Error("cancel task failed", zap.Error(err))
}
})

var retryable bool
for i := 0; i < retryTimes; i++ {
retryable, err = d.taskMgr.UpdateGlobalTaskAndAddSubTasks(d.Task, newSubTasks, prevState)
Expand Down
10 changes: 8 additions & 2 deletions disttask/framework/dispatcher/dispatcher_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,15 @@ func (dm *Manager) dispatchTaskLoop() {
}

// TODO: Consider getting these tasks, in addition to the task being worked on.
tasks, err := dm.taskMgr.GetGlobalTasksInStates(proto.TaskStatePending, proto.TaskStateRunning, proto.TaskStateReverting, proto.TaskStateCancelling)
tasks, err := dm.taskMgr.GetGlobalTasksInStates(
proto.TaskStatePending,
proto.TaskStateRunning,
proto.TaskStateReverting,
proto.TaskStateCancelling,
proto.TaskStateResuming,
)
if err != nil {
logutil.BgLogger().Warn("get unfinished(pending, running, reverting or cancelling) tasks failed", zap.Error(err))
logutil.BgLogger().Warn("get unfinished(pending, running, reverting, cancelling, resuming) tasks failed", zap.Error(err))
break
}

Expand Down
53 changes: 43 additions & 10 deletions disttask/framework/dispatcher/dispatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ func TestTaskFailInManager(t *testing.T) {
}, time.Second*10, time.Millisecond*300)
}

func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel bool) {
func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel, isPauseAndResume bool) {
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/domain/MockDisableDistTask", "return(true)"))
defer func() {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/domain/MockDisableDistTask"))
Expand Down Expand Up @@ -336,14 +336,16 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel

// test DetectTaskLoop
checkGetTaskState := func(expectedState string) {
for i := 0; i < cnt; i++ {
i := 0
for ; i < cnt; i++ {
tasks, err := mgr.GetGlobalTasksInStates(expectedState)
require.NoError(t, err)
if len(tasks) == taskCnt {
break
}
time.Sleep(time.Millisecond * 50)
}
require.Less(t, i, cnt)
}
// Test all subtasks are successful.
var err error
Expand All @@ -365,6 +367,29 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel
err = mgr.CancelGlobalTask(int64(i))
require.NoError(t, err)
}
} else if isPauseAndResume {
for i := 1; i <= taskCnt; i++ {
err = mgr.PauseTask(int64(i))
require.NoError(t, err)
}
for i := 1; i <= subtaskCnt*taskCnt; i++ {
err = mgr.UpdateSubtaskStateAndError(int64(i), proto.TaskStatePaused, nil)
require.NoError(t, err)
}
checkGetTaskState(proto.TaskStatePaused)
for i := 1; i <= taskCnt; i++ {
found, err := mgr.ResumeTask(int64(i))
require.Equal(t, true, found)
require.NoError(t, err)
}

// Mock subtasks succeed.
for i := 1; i <= subtaskCnt*taskCnt; i++ {
err = mgr.UpdateSubtaskStateAndError(int64(i), proto.TaskStateSucceed, nil)
require.NoError(t, err)
}
checkGetTaskState(proto.TaskStateSucceed)
return
} else {
// Test each task has a subtask failed.
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/storage/MockUpdateTaskErr", "1*return(true)"))
Expand Down Expand Up @@ -401,35 +426,43 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel
}

func TestSimple(t *testing.T) {
checkDispatch(t, 1, true, false, false)
checkDispatch(t, 1, true, false, false, false)
}

func TestSimpleErrStage(t *testing.T) {
checkDispatch(t, 1, false, false, false)
checkDispatch(t, 1, false, false, false, false)
}

func TestSimpleCancel(t *testing.T) {
checkDispatch(t, 1, false, true, false)
checkDispatch(t, 1, false, true, false, false)
}

func TestSimpleSubtaskCancel(t *testing.T) {
checkDispatch(t, 1, false, false, true)
checkDispatch(t, 1, false, false, true, false)
}

func TestParallel(t *testing.T) {
checkDispatch(t, 3, true, false, false)
checkDispatch(t, 3, true, false, false, false)
}

func TestParallelErrStage(t *testing.T) {
checkDispatch(t, 3, false, false, false)
checkDispatch(t, 3, false, false, false, false)
}

func TestParallelCancel(t *testing.T) {
checkDispatch(t, 3, false, true, false)
checkDispatch(t, 3, false, true, false, false)
}

func TestParallelSubtaskCancel(t *testing.T) {
checkDispatch(t, 3, false, false, true)
checkDispatch(t, 3, false, false, true, false)
}

// TestPause runs checkDispatch with a single task and isPauseAndResume set
// to true, so the pause-then-resume branch of the helper is exercised.
func TestPause(t *testing.T) {
checkDispatch(t, 1, false, false, false, true)
}

// TestParallelPause runs checkDispatch with three tasks and isPauseAndResume
// set to true, so the pause-then-resume branch is exercised for several
// tasks at once.
func TestParallelPause(t *testing.T) {
checkDispatch(t, 3, false, false, false, true)
}

func TestVerifyTaskStateTransform(t *testing.T) {
Expand Down
84 changes: 84 additions & 0 deletions disttask/framework/framework_pause_and_resume_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package framework_test

import (
"sync"
"testing"

"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/disttask/framework/dispatcher"
"github.com/pingcap/tidb/disttask/framework/handle"
"github.com/pingcap/tidb/disttask/framework/proto"
"github.com/pingcap/tidb/disttask/framework/storage"
"github.com/pingcap/tidb/testkit"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)

// CheckSubtasksState asserts that the number of subtasks of taskID found in
// the given state, plus the number of subtasks of taskID found in the
// history storage, equals expectedCnt. Before counting, it prints the
// current and history subtask info of the task.
func CheckSubtasksState(t *testing.T, taskID int64, state string, expectedCnt int64) {
	taskMgr, err := storage.GetTaskManager()
	require.NoError(t, err)

	// Dump subtask info before counting.
	taskMgr.PrintSubtaskInfo(taskID)
	taskMgr.PrintHistorySubtaskInfo(taskID)

	inStateCnt, err := taskMgr.GetSubtaskInStatesCnt(taskID, state)
	require.NoError(t, err)
	archivedCnt, err := storage.GetSubtasksFromHistoryByTaskIDForTest(taskMgr, taskID)
	require.NoError(t, err)

	total := inStateCnt + int64(archivedCnt)
	require.Equal(t, expectedCnt, total)
}

// TestFrameworkPauseAndResume checks that a task can be paused and resumed
// in two situations: while it is running and while it is still pending.
// In both cases the task is driven to the paused state through a failpoint,
// resumed through handle.ResumeTask, and then expected to finish with four
// succeeded subtasks and no subtask errors.
func TestFrameworkPauseAndResume(t *testing.T) {
	var m sync.Map
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	RegisterTaskMeta(t, ctrl, &m, &testDispatcherExt{})
	distContext := testkit.NewDistExecutionContext(t, 3)
	// Deferred so the context is closed even when an assertion aborts the test.
	defer distContext.Close()

	// 1. dispatch and pause one running task.
	require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/dispatcher/pauseTaskAfterRefreshTask", "2*return(true)"))
	require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/dispatcher/syncAfterResume", "return()"))
	DispatchTaskAndCheckState("key1", t, &m, proto.TaskStatePaused)
	require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/dispatcher/pauseTaskAfterRefreshTask"))
	require.NoError(t, handle.ResumeTask("key1"))
	// The dispatcher signals on TestSyncChan after it tried to move the task
	// back to the running state.
	<-dispatcher.TestSyncChan
	task1 := WaitTaskExit(t, "key1")
	// 4 subtasks are expected for the task in total.
	CheckSubtasksState(t, task1.ID, proto.TaskStateSucceed, 4)
	require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/dispatcher/syncAfterResume"))

	mgr, err := storage.GetTaskManager()
	require.NoError(t, err)
	errs, err := mgr.CollectSubTaskError(task1.ID)
	require.NoError(t, err)
	require.Empty(t, errs)

	// 2. pause pending task.
	require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/dispatcher/pausePendingTask", "2*return(true)"))
	require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/dispatcher/syncAfterResume", "1*return()"))
	DispatchTaskAndCheckState("key2", t, &m, proto.TaskStatePaused)

	require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/dispatcher/pausePendingTask"))
	require.NoError(t, handle.ResumeTask("key2"))
	<-dispatcher.TestSyncChan
	task2 := WaitTaskExit(t, "key2")
	// Assert on the task created for "key2" itself. A hard-coded ID of 1
	// would re-check the first task and pass regardless of this scenario.
	CheckSubtasksState(t, task2.ID, proto.TaskStateSucceed, 4)
	require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/dispatcher/syncAfterResume"))

	errs, err = mgr.CollectSubTaskError(task2.ID)
	require.NoError(t, err)
	require.Empty(t, errs)
}
13 changes: 9 additions & 4 deletions disttask/framework/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,22 +187,27 @@ func RegisterTaskMetaForExample3(t *testing.T, ctrl *gomock.Controller, m *sync.
func DispatchTask(taskKey string, t *testing.T) *proto.Task {
mgr, err := storage.GetTaskManager()
require.NoError(t, err)
taskID, err := mgr.AddNewGlobalTask(taskKey, proto.TaskTypeExample, 8, nil)
_, err = mgr.AddNewGlobalTask(taskKey, proto.TaskTypeExample, 8, nil)
require.NoError(t, err)
start := time.Now()
return WaitTaskExit(t, taskKey)
}

func WaitTaskExit(t *testing.T, taskKey string) *proto.Task {
mgr, err := storage.GetTaskManager()
require.NoError(t, err)
var task *proto.Task
start := time.Now()
for {
if time.Since(start) > 10*time.Minute {
require.FailNow(t, "timeout")
}

time.Sleep(time.Second)
task, err = mgr.GetGlobalTaskByID(taskID)
task, err = mgr.GetGlobalTaskByKey(taskKey)

require.NoError(t, err)
require.NotNil(t, task)
if task.State != proto.TaskStatePending && task.State != proto.TaskStateRunning && task.State != proto.TaskStateCancelling && task.State != proto.TaskStateReverting {
if task.State != proto.TaskStatePending && task.State != proto.TaskStateRunning && task.State != proto.TaskStateCancelling && task.State != proto.TaskStateReverting && task.State != proto.TaskStatePausing {
break
}
}
Expand Down
Loading