disttask: refine dispatcher #45460

Merged: 24 commits merged on Aug 1, 2023.
Changes from 16 commits
2 changes: 1 addition & 1 deletion disttask/framework/BUILD.bazel
@@ -9,7 +9,7 @@ go_test(
],
flaky = True,
race = "on",
shard_count = 9,
shard_count = 10,
deps = [
"//disttask/framework/dispatcher",
"//disttask/framework/proto",
101 changes: 58 additions & 43 deletions disttask/framework/dispatcher/dispatcher.go
@@ -51,14 +51,16 @@
retrySQLInterval = 500 * time.Millisecond
)

// Dispatch defines the interface for operations inside a dispatcher.
type Dispatch interface {
// Dispatcher defines the interface for operations inside a dispatcher.
type Dispatcher interface {
// Start enables dispatching and monitoring mechanisms.
Start()
// GetAllSchedulerIDs gets all the task's available scheduler instances.
GetAllSchedulerIDs(ctx context.Context, handle TaskFlowHandle, gTask *proto.Task) ([]string, error)
// Stop stops the dispatcher.
Stop()
// Inited checks if the dispatcher is started.
Inited() bool
}

// TaskHandle provides the interface for operations needed by task flow handles.
@@ -80,7 +82,6 @@
d.runningGTasks.Lock()
d.runningGTasks.taskIDs[gTask.ID] = struct{}{}
d.runningGTasks.Unlock()
d.detectPendingGTaskCh <- gTask
}

func (d *dispatcher) isRunningGTask(globalTaskID int64) bool {
@@ -96,12 +97,15 @@
delete(d.runningGTasks.taskIDs, globalTaskID)
}

// dispatcher dispatches and monitors tasks.
// The number of monitored tasks is limited by the size of gPool.
type dispatcher struct {
ctx context.Context
cancel context.CancelFunc
taskMgr *storage.TaskManager
wg tidbutil.WaitGroupWrapper
gPool *spool.Pool
inited bool

runningGTasks struct {
syncutil.RWMutex
@@ -111,35 +115,43 @@
}

// NewDispatcher creates a dispatcher struct.
func NewDispatcher(ctx context.Context, taskTable *storage.TaskManager) (Dispatch, error) {
func NewDispatcher(ctx context.Context, taskTable *storage.TaskManager) (Dispatcher, error) {
dispatcher := &dispatcher{
taskMgr: taskTable,
detectPendingGTaskCh: make(chan *proto.Task, DefaultDispatchConcurrency),
}
pool, err := spool.NewPool("dispatch_pool", int32(DefaultDispatchConcurrency), util.DistTask, spool.WithBlocking(true))
gPool, err := spool.NewPool("dispatch_pool", int32(DefaultDispatchConcurrency), util.DistTask, spool.WithBlocking(true))
if err != nil {
return nil, err
}
dispatcher.gPool = pool
dispatcher.gPool = gPool
dispatcher.ctx, dispatcher.cancel = context.WithCancel(ctx)
dispatcher.runningGTasks.taskIDs = make(map[int64]struct{})

return dispatcher, nil
}

// Start implements Dispatch.Start interface.
// Start implements Dispatcher.Start interface.
func (d *dispatcher) Start() {
d.wg.Run(d.DispatchTaskLoop)
d.wg.Run(d.DetectTaskLoop)
d.inited = true
}

// Stop implements Dispatch.Stop interface.
// Stop implements Dispatcher.Stop interface.
func (d *dispatcher) Stop() {
d.cancel()
d.gPool.ReleaseAndWait()
d.wg.Wait()
d.inited = false
}

func (d *dispatcher) Inited() bool {
Contributor: We don't use this function. Could we remove it?

Contributor (Author): I think Inited is better for checking whether the dispatcher is started.

return d.inited
}
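
A minimal sketch of how a caller (for example, a test) might use Inited to wait for the dispatcher to come up, for context on the review thread above; the waitForDispatcherReady helper and its polling interval are hypothetical and not part of this PR.

package dispatcher_test

import (
	"time"

	"github.com/pingcap/tidb/disttask/framework/dispatcher"
)

// waitForDispatcherReady polls Inited until the dispatcher reports it has
// started or the timeout expires. Hypothetical helper, not part of this PR.
func waitForDispatcherReady(d dispatcher.Dispatcher, timeout time.Duration) bool {
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		if d.Inited() {
			return true
		}
		time.Sleep(10 * time.Millisecond)
	}
	return false
}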

// MockOwnerChange mocks owner change in tests.
var MockOwnerChange func()

// DispatchTaskLoop dispatches the global tasks.
func (d *dispatcher) DispatchTaskLoop() {
logutil.BgLogger().Info("dispatch task loop start")
@@ -175,6 +187,7 @@
// the task is not in runningGTasks set when:
// owner changed or task is cancelled when status is pending.
if gTask.State == proto.TaskStateRunning || gTask.State == proto.TaskStateReverting || gTask.State == proto.TaskStateCancelling {
d.executeTask(gTask)
d.setRunningGTask(gTask)
cnt++
continue
@@ -184,20 +197,26 @@
break
}

err = d.processNormalFlow(gTask)
logutil.BgLogger().Info("dispatch task loop", zap.Int64("task ID", gTask.ID),
zap.String("state", gTask.State), zap.Uint64("concurrency", gTask.Concurrency), zap.Error(err))
if err != nil || gTask.IsFinished() {
continue
}
d.executeTask(gTask)
d.setRunningGTask(gTask)
cnt++
}
}
}
}

func (d *dispatcher) probeTask(taskID int64) (gTask *proto.Task, finished bool, subTaskErrs []error) {
func (d *dispatcher) executeTask(gTask *proto.Task) {
// The pool is created with blocking enabled, so Run won't return an error.
_ = d.gPool.Run(func() {
Contributor: How about moving this to the dispatcher manager?

func (dm *DispatcherManager) StartDispatcher() {
	dm.gPool.Run(func() {
		d := newDispatcher()
		d.ExecuteTask()
		d.delRunningTask()
	})
}

Then we wouldn't need the finished channel.

logutil.BgLogger().Info("execute one task", zap.Int64("task ID", gTask.ID),
zap.String("state", gTask.State), zap.Uint64("concurrency", gTask.Concurrency))
d.scheduleTask(gTask.ID)
})
}
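
A slightly fuller sketch of the DispatcherManager idea from the review comment above, assumed to live alongside the existing dispatcher code so it can reuse the spool and proto imports; newDispatcher, ExecuteTask, and delRunningTask are hypothetical names, and this is an illustration rather than code from the PR.

// Hypothetical sketch: the manager owns the goroutine pool and runs one
// dispatcher per task, so each dispatcher cleans itself up when its task
// reaches a terminal state instead of signalling through a finished channel.
type DispatcherManager struct {
	gPool *spool.Pool
}

func (dm *DispatcherManager) StartDispatcher(gTask *proto.Task) error {
	// The pool is created with blocking enabled, so Run waits for a free
	// worker instead of returning an error when the pool is full.
	return dm.gPool.Run(func() {
		d := newDispatcher(gTask) // hypothetical per-task constructor
		d.ExecuteTask()           // drive the task step by step until it finishes
		d.delRunningTask()        // remove the task from the running set
	})
}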

// monitorTask checks whether the current step of one task is finished,
// and gathers subTaskErrs to handle subtask failures.
func (d *dispatcher) monitorTask(taskID int64) (gTask *proto.Task, finished bool, subTaskErrs []error) {
// TODO: Consider putting the following operations into a transaction.
gTask, err := d.taskMgr.GetGlobalTaskByID(taskID)
if err != nil {
@@ -233,40 +252,25 @@
}
}

// DetectTaskLoop monitors the status of the subtasks and processes them.
func (d *dispatcher) DetectTaskLoop() {
logutil.BgLogger().Info("detect task loop start")
for {
select {
case <-d.ctx.Done():
logutil.BgLogger().Info("detect task loop exits", zap.Error(d.ctx.Err()))
return
case task := <-d.detectPendingGTaskCh:
// Using the pool with block, so it wouldn't return an error.
_ = d.gPool.Run(func() { d.detectTask(task.ID) })
}
}
}

func (d *dispatcher) detectTask(taskID int64) {
// scheduleTask schedules the task execution step by step.
func (d *dispatcher) scheduleTask(taskID int64) {
ticker := time.NewTicker(checkTaskFinishedInterval)
defer ticker.Stop()

for {
select {
case <-d.ctx.Done():
logutil.BgLogger().Info("detect task exits", zap.Int64("task ID", taskID), zap.Error(d.ctx.Err()))
logutil.BgLogger().Info("schedule task exits", zap.Int64("task ID", taskID), zap.Error(d.ctx.Err()))
return
case <-ticker.C:
failpoint.Inject("cancelTaskBeforeProbe", func(val failpoint.Value) {
if val.(bool) {
gTask, stepIsFinished, errs := d.monitorTask(taskID)
failpoint.Inject("cancelTaskAfterMonitorTask", func(val failpoint.Value) {
if val.(bool) && gTask.State == proto.TaskStateRunning {
err := d.taskMgr.CancelGlobalTask(taskID)
if err != nil {
logutil.BgLogger().Error("cancel global task failed", zap.Error(err))
logutil.BgLogger().Error("cancel task failed", zap.Error(err))

[Codecov warning: added line #L270 was not covered by tests]
}
}
})
gTask, stepIsFinished, errs := d.probeTask(taskID)
// The global task isn't finished and not failed.
if !stepIsFinished && len(errs) == 0 {
GetTaskFlowHandle(gTask.Type).OnTicker(d.ctx, gTask)
@@ -287,6 +291,14 @@
zap.Int64("task-id", gTask.ID), zap.String("state", gTask.State))
}
}

failpoint.Inject("mockOwnerChange", func(val failpoint.Value) {
if val.(bool) {
logutil.BgLogger().Info("mockOwnerChange called")
MockOwnerChange()
time.Sleep(time.Second)
}
})
}
}

@@ -335,7 +347,7 @@
// 1. generate the needed global task meta and subTask meta (dist-plan).
handle := GetTaskFlowHandle(gTask.Type)
if handle == nil {
logutil.BgLogger().Warn("gen gTask flow handle failed, this type handle doesn't register", zap.Int64("ID", gTask.ID), zap.String("type", gTask.Type))
logutil.BgLogger().Warn("gen task flow handle failed, this type handle doesn't register", zap.Int64("ID", gTask.ID), zap.String("type", gTask.Type))

[Codecov warning: added line #L350 was not covered by tests]
return d.updateTask(gTask, proto.TaskStateReverted, nil, retrySQLTimes)
}
meta, err := handle.ProcessErrFlow(d.ctx, d, gTask, receiveErr)
@@ -351,7 +363,7 @@
func (d *dispatcher) dispatchSubTask4Revert(gTask *proto.Task, handle TaskFlowHandle, meta []byte) error {
instanceIDs, err := d.GetAllSchedulerIDs(d.ctx, handle, gTask)
if err != nil {
logutil.BgLogger().Warn("get global task's all instances failed", zap.Error(err))
logutil.BgLogger().Warn("get task's all instances failed", zap.Error(err))

[Codecov warning: added line #L366 was not covered by tests]
return err
}

@@ -370,7 +382,7 @@
// 1. generate the needed global task meta and subTask meta (dist-plan).
handle := GetTaskFlowHandle(gTask.Type)
if handle == nil {
logutil.BgLogger().Warn("gen gTask flow handle failed, this type handle doesn't register", zap.Int64("ID", gTask.ID), zap.String("type", gTask.Type))
logutil.BgLogger().Warn("gen task flow handle failed, this type handle doesn't register", zap.Int64("ID", gTask.ID), zap.String("type", gTask.Type))
gTask.Error = errors.New("unsupported task type")
return d.updateTask(gTask, proto.TaskStateReverted, nil, retrySQLTimes)
}
@@ -436,7 +448,7 @@
pos := i % len(serverNodes)
instanceID := disttaskutil.GenerateExecID(serverNodes[pos].IP, serverNodes[pos].Port)
logutil.BgLogger().Debug("create subtasks",
zap.Int("gTask.ID", int(gTask.ID)), zap.String("type", gTask.Type), zap.String("instanceID", instanceID))
zap.Int("task.ID", int(gTask.ID)), zap.String("type", gTask.Type), zap.String("instanceID", instanceID))
subTasks = append(subTasks, proto.NewSubtask(gTask.ID, gTask.Type, instanceID, meta))
}

@@ -483,6 +495,7 @@
return ids, nil
}

// GetPreviousSubtaskMetas gets subtask metas from the specified step.
func (d *dispatcher) GetPreviousSubtaskMetas(gTaskID int64, step int64) ([][]byte, error) {
previousSubtasks, err := d.taskMgr.GetSucceedSubtasksByStep(gTaskID, step)
if err != nil {
@@ -496,17 +509,19 @@
return previousSubtaskMetas, nil
}

// WithNewSession executes the function with a new session.
func (d *dispatcher) WithNewSession(fn func(se sessionctx.Context) error) error {
return d.taskMgr.WithNewSession(fn)
}

// WithNewTxn executes the fn in a new transaction.
func (d *dispatcher) WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error {
return d.taskMgr.WithNewTxn(ctx, fn)
}

func (*dispatcher) checkConcurrencyOverflow(cnt int) bool {
if cnt >= DefaultDispatchConcurrency {
logutil.BgLogger().Info("dispatch task loop, running GTask cnt is more than concurrency",
logutil.BgLogger().Info("dispatch task loop, running task cnt is more than concurrency limitation",
zap.Int("running cnt", cnt), zap.Int("concurrency", DefaultDispatchConcurrency))
return true
}
26 changes: 21 additions & 5 deletions disttask/framework/dispatcher/dispatcher_test.go
@@ -57,7 +57,7 @@ func (*testFlowHandle) IsRetryableErr(error) bool {
return true
}

func MockDispatcher(t *testing.T, pool *pools.ResourcePool) (dispatcher.Dispatch, *storage.TaskManager) {
func MockDispatcher(t *testing.T, pool *pools.ResourcePool) (dispatcher.Dispatcher, *storage.TaskManager) {
ctx := context.Background()
mgr := storage.NewTaskManager(util.WithInternalSourceType(ctx, "taskManager"), pool)
storage.SetTaskManager(mgr)
@@ -189,6 +189,23 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool, isCancel bool) {
require.Equal(t, retCnt, taskCnt)
}

checkTaskRunningCnt := func() []*proto.Task {
var retCnt int
var tasks []*proto.Task
var err error
for i := 0; i < cnt; i++ {
tasks, err = mgr.GetGlobalTasksInStates(proto.TaskStateRunning)
require.NoError(t, err)
retCnt = len(tasks)
if retCnt == taskCnt {
break
}
time.Sleep(time.Millisecond * 50)
}
require.Equal(t, retCnt, taskCnt)
return tasks
}

// Mock add tasks.
taskIDs := make([]int64, 0, taskCnt)
for i := 0; i < taskCnt; i++ {
@@ -198,9 +215,7 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool, isCancel bool) {
}
// test normal flow
checkGetRunningGTaskCnt()
tasks, err := mgr.GetGlobalTasksInStates(proto.TaskStateRunning)
require.NoError(t, err)
require.Len(t, tasks, taskCnt)
tasks := checkTaskRunningCnt()
for i, taskID := range taskIDs {
require.Equal(t, int64(i+1), tasks[i].ID)
subtasks, err := mgr.GetSubtaskInStatesCnt(taskID, proto.TaskStatePending)
@@ -220,7 +235,7 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool, isCancel bool) {
// test DetectTaskLoop
checkGetGTaskState := func(expectedState string) {
for i := 0; i < cnt; i++ {
tasks, err = mgr.GetGlobalTasksInStates(expectedState)
tasks, err := mgr.GetGlobalTasksInStates(expectedState)
require.NoError(t, err)
if len(tasks) == taskCnt {
break
@@ -229,6 +244,7 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool, isCancel bool) {
}
}
// Test all subtasks are successful.
var err error
if isSucc {
// Mock subtasks succeed.
for i := 1; i <= subtaskCnt*taskCnt; i++ {
4 changes: 2 additions & 2 deletions disttask/framework/framework_rollback_test.go
@@ -123,9 +123,9 @@ func TestFrameworkRollback(t *testing.T) {
var v atomic.Int64
RegisterRollbackTaskMeta(&v)
distContext := testkit.NewDistExecutionContext(t, 2)
failpoint.Enable("github.com/pingcap/tidb/disttask/framework/dispatcher/cancelTaskBeforeProbe", "1*return(true)")
failpoint.Enable("github.com/pingcap/tidb/disttask/framework/dispatcher/cancelTaskAfterMonitorTask", "2*return(true)")
defer func() {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/dispatcher/cancelTaskBeforeProbe"))
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/dispatcher/cancelTaskAfterMonitorTask"))
}()

DispatchTaskAndCheckFail("key2", t, &v)