Skip to content

Commit

Permalink
disttask: alloc or occupy resource when there are subtasks to run & fix scheduler block task (#50650)
Browse files Browse the repository at this point in the history

ref #49008
  • Loading branch information
D3Hunter authored Jan 24, 2024
1 parent 7c5f9d4 commit 40f0d16
Show file tree
Hide file tree
Showing 15 changed files with 285 additions and 145 deletions.
16 changes: 16 additions & 0 deletions pkg/disttask/framework/mock/task_executor_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pkg/disttask/framework/scheduler/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ func (sm *Manager) DelRunningTask(id int64) {
sm.delScheduler(id)
}

// DoCleanUpRoutine implements Scheduler.DoCleanUpRoutine interface.
func (sm *Manager) DoCleanUpRoutine() {
// DoCleanupRoutine implements Scheduler.DoCleanupRoutine interface.
func (sm *Manager) DoCleanupRoutine() {
sm.doCleanupTask()
}

Expand Down
7 changes: 6 additions & 1 deletion pkg/disttask/framework/scheduler/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ func (s *BaseScheduler) refreshTask() error {
task := s.GetTask()
newTask, err := s.taskMgr.GetTaskByID(s.ctx, task.ID)
if err != nil {
s.logger.Error("refresh task failed", zap.Error(err))
return err
}
s.task.Store(newTask)
Expand All @@ -156,6 +155,12 @@ func (s *BaseScheduler) scheduleTask() {
case <-ticker.C:
err := s.refreshTask()
if err != nil {
if errors.Cause(err) == storage.ErrTaskNotFound {
// this can happen when the task is reverted/succeeded, but before
// we reach here, the cleanup routine has already moved it to history.
return
}
s.logger.Error("refresh task failed", zap.Error(err))
continue
}
task := *s.GetTask()
Expand Down
33 changes: 17 additions & 16 deletions pkg/disttask/framework/scheduler/scheduler_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ var (
checkTaskRunningInterval = 3 * time.Second
// defaultHistorySubtaskTableGcInterval is the interval of gc history subtask table.
defaultHistorySubtaskTableGcInterval = 24 * time.Hour
// DefaultCleanUpInterval is the interval of cleanUp routine.
// DefaultCleanUpInterval is the interval of cleanup routine.
DefaultCleanUpInterval = 10 * time.Minute
)

Expand Down Expand Up @@ -318,13 +318,13 @@ func (sm *Manager) startScheduler(basicTask *proto.Task, reservedExecID string)
}

func (sm *Manager) cleanupTaskLoop() {
logutil.BgLogger().Info("cleanUp loop start")
logutil.BgLogger().Info("cleanup loop start")
ticker := time.NewTicker(DefaultCleanUpInterval)
defer ticker.Stop()
for {
select {
case <-sm.ctx.Done():
logutil.BgLogger().Info("cleanUp loop exits", zap.Error(sm.ctx.Err()))
logutil.BgLogger().Info("cleanup loop exits", zap.Error(sm.ctx.Err()))
return
case <-sm.finishCh:
sm.doCleanupTask()
Expand All @@ -337,7 +337,7 @@ func (sm *Manager) cleanupTaskLoop() {
// WaitCleanUpFinished is used to sync the test.
var WaitCleanUpFinished = make(chan struct{}, 1)

// doCleanupTask processes clean up routine defined by each type of tasks and cleanUpMeta.
// doCleanupTask processes clean up routine defined by each type of tasks and cleanupMeta.
// For example:
//
// tasks with global sort should clean up tmp files stored on S3.
Expand All @@ -349,44 +349,45 @@ func (sm *Manager) doCleanupTask() {
proto.TaskStateSucceed,
)
if err != nil {
logutil.BgLogger().Warn("cleanUp routine failed", zap.Error(err))
logutil.BgLogger().Warn("cleanup routine failed", zap.Error(err))
return
}
if len(tasks) == 0 {
return
}
logutil.BgLogger().Info("cleanUp routine start")
err = sm.cleanUpFinishedTasks(tasks)
logutil.BgLogger().Info("cleanup routine start")
err = sm.cleanupFinishedTasks(tasks)
if err != nil {
logutil.BgLogger().Warn("cleanUp routine failed", zap.Error(err))
logutil.BgLogger().Warn("cleanup routine failed", zap.Error(err))
return
}
failpoint.Inject("WaitCleanUpFinished", func() {
WaitCleanUpFinished <- struct{}{}
})
logutil.BgLogger().Info("cleanUp routine success")
logutil.BgLogger().Info("cleanup routine success")
}

func (sm *Manager) cleanUpFinishedTasks(tasks []*proto.Task) error {
func (sm *Manager) cleanupFinishedTasks(tasks []*proto.Task) error {
cleanedTasks := make([]*proto.Task, 0)
var firstErr error
for _, task := range tasks {
cleanUpFactory := getSchedulerCleanUpFactory(task.Type)
if cleanUpFactory != nil {
cleanUp := cleanUpFactory()
err := cleanUp.CleanUp(sm.ctx, task)
logutil.BgLogger().Info("cleanup task", zap.Int64("task-id", task.ID))
cleanupFactory := getSchedulerCleanUpFactory(task.Type)
if cleanupFactory != nil {
cleanup := cleanupFactory()
err := cleanup.CleanUp(sm.ctx, task)
if err != nil {
firstErr = err
break
}
cleanedTasks = append(cleanedTasks, task)
} else {
// if task doesn't register cleanUp function, mark it as cleaned.
// if task doesn't register cleanup function, mark it as cleaned.
cleanedTasks = append(cleanedTasks, task)
}
}
if firstErr != nil {
logutil.BgLogger().Warn("cleanUp routine failed", zap.Error(errors.Trace(firstErr)))
logutil.BgLogger().Warn("cleanup routine failed", zap.Error(errors.Trace(firstErr)))
}

failpoint.Inject("mockTransferErr", func() {
Expand Down
2 changes: 1 addition & 1 deletion pkg/disttask/framework/scheduler/scheduler_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func TestCleanUpRoutine(t *testing.T) {
err = mgr.UpdateSubtaskStateAndError(ctx, ":4000", int64(i), proto.SubtaskStateSucceed, nil)
require.NoError(t, err)
}
sch.DoCleanUpRoutine()
sch.DoCleanupRoutine()
require.Eventually(t, func() bool {
tasks, err := testutil.GetTasksFromHistoryInStates(ctx, mgr, proto.TaskStateSucceed)
require.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion pkg/disttask/framework/scheduler/scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func getNumberExampleSchedulerExt(ctrl *gomock.Controller) scheduler.Extension {
return mockScheduler
}

func MockSchedulerManager(t *testing.T, ctrl *gomock.Controller, pool *pools.ResourcePool, ext scheduler.Extension, cleanUp scheduler.CleanUpRoutine) (*scheduler.Manager, *storage.TaskManager) {
func MockSchedulerManager(t *testing.T, ctrl *gomock.Controller, pool *pools.ResourcePool, ext scheduler.Extension, cleanup scheduler.CleanUpRoutine) (*scheduler.Manager, *storage.TaskManager) {
ctx := context.WithValue(context.Background(), "etcd", true)
mgr := storage.NewTaskManager(pool)
storage.SetTaskManager(mgr)
Expand Down
2 changes: 1 addition & 1 deletion pkg/disttask/framework/storage/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ go_test(
embed = [":storage"],
flaky = True,
race = "on",
shard_count = 21,
shard_count = 22,
deps = [
"//pkg/config",
"//pkg/disttask/framework/proto",
Expand Down
65 changes: 65 additions & 0 deletions pkg/disttask/framework/storage/table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1103,3 +1103,68 @@ func TestSubtasksState(t *testing.T) {
require.NoError(t, err)
require.Greater(t, endTime, ts)
}

// checkBasicTaskEq asserts that task equals expectedTask on the basic task
// fields only: ID, Key, Type, State, Step, Priority, Concurrency and
// CreateTime. Any other field of proto.Task is not compared.
func checkBasicTaskEq(t *testing.T, expectedTask, task *proto.Task) {
	// mark this function as a test helper so that a failed assertion is
	// attributed to the calling test line instead of this function.
	t.Helper()
	require.Equal(t, expectedTask.ID, task.ID)
	require.Equal(t, expectedTask.Key, task.Key)
	require.Equal(t, expectedTask.Type, task.Type)
	require.Equal(t, expectedTask.State, task.State)
	require.Equal(t, expectedTask.Step, task.Step)
	require.Equal(t, expectedTask.Priority, task.Priority)
	require.Equal(t, expectedTask.Concurrency, task.Concurrency)
	require.Equal(t, expectedTask.CreateTime, task.CreateTime)
}

// TestGetActiveTaskExecInfo verifies that GetTaskExecInfoByExecID only returns
// tasks that have pending/running subtasks of the task's current step on the
// queried exec node, together with the subtask concurrency.
func TestGetActiveTaskExecInfo(t *testing.T) {
	_, tm, ctx := testutil.InitTableTest(t)

	require.NoError(t, tm.InitMeta(ctx, ":4000", ""))
	taskStates := []proto.TaskState{proto.TaskStateRunning, proto.TaskStateReverting, proto.TaskStateReverting, proto.TaskStatePausing}
	tasks := make([]*proto.Task, 0, len(taskStates))
	for i, expectedState := range taskStates {
		taskID, err := tm.CreateTask(ctx, fmt.Sprintf("key-%d", i), proto.TaskTypeExample, 8, []byte(""))
		require.NoError(t, err)
		task, err := tm.GetTaskByID(ctx, taskID)
		require.NoError(t, err)
		tasks = append(tasks, task)
		// every task is first moved to running state at step two, then
		// switched to its target state below.
		require.NoError(t, tm.SwitchTaskStep(ctx, task, proto.TaskStateRunning, proto.StepTwo, nil))
		// keep the in-memory copy in sync with what we expect to read back.
		task.State = expectedState
		task.Step = proto.StepTwo
		switch expectedState {
		case proto.TaskStateReverting:
			require.NoError(t, tm.RevertTask(ctx, task.ID, proto.TaskStateRunning, nil))
		case proto.TaskStatePausing:
			_, err = tm.PauseTask(ctx, task.Key)
			require.NoError(t, err)
		}
	}
	// mock a pending subtask of step 1, this should not happen, just for test
	testutil.InsertSubtask(t, tm, tasks[0].ID, proto.StepOne, ":4000", []byte("test"), proto.SubtaskStatePending, proto.TaskTypeExample, 4)
	testutil.InsertSubtask(t, tm, tasks[0].ID, proto.StepTwo, ":4000", []byte("test"), proto.SubtaskStatePending, proto.TaskTypeExample, 4)
	testutil.InsertSubtask(t, tm, tasks[0].ID, proto.StepTwo, ":4000", []byte("test"), proto.SubtaskStateRunning, proto.TaskTypeExample, 4)
	testutil.InsertSubtask(t, tm, tasks[0].ID, proto.StepTwo, ":4001", []byte("test"), proto.SubtaskStateSucceed, proto.TaskTypeExample, 4)
	testutil.InsertSubtask(t, tm, tasks[0].ID, proto.StepTwo, ":4001", []byte("test"), proto.SubtaskStatePending, proto.TaskTypeExample, 4)
	// tasks[1] has no subtask
	testutil.InsertSubtask(t, tm, tasks[2].ID, proto.StepTwo, ":4001", []byte("test"), proto.SubtaskStatePending, proto.TaskTypeExample, 6)
	testutil.InsertSubtask(t, tm, tasks[3].ID, proto.StepTwo, ":4001", []byte("test"), proto.SubtaskStateRunning, proto.TaskTypeExample, 8)

	// only check that querying active subtasks succeeds, the returned
	// subtasks are not inspected.
	// NOTE(review): the literal 1 is presumably the ID of tasks[0] — confirm.
	_, err := tm.GetActiveSubtasks(ctx, 1)
	require.NoError(t, err)
	// :4000, only tasks[0] has pending/running subtasks of step two here.
	taskExecInfos, err := tm.GetTaskExecInfoByExecID(ctx, ":4000")
	require.NoError(t, err)
	require.Len(t, taskExecInfos, 1)
	checkBasicTaskEq(t, tasks[0], taskExecInfos[0].Task)
	require.Equal(t, 4, taskExecInfos[0].SubtaskConcurrency)
	// :4001, tasks[0], tasks[2] and tasks[3] have pending/running subtasks of
	// step two here.
	taskExecInfos, err = tm.GetTaskExecInfoByExecID(ctx, ":4001")
	require.NoError(t, err)
	require.Len(t, taskExecInfos, 3)
	checkBasicTaskEq(t, tasks[0], taskExecInfos[0].Task)
	require.Equal(t, 4, taskExecInfos[0].SubtaskConcurrency)
	checkBasicTaskEq(t, tasks[2], taskExecInfos[1].Task)
	require.Equal(t, 6, taskExecInfos[1].SubtaskConcurrency)
	checkBasicTaskEq(t, tasks[3], taskExecInfos[2].Task)
	require.Equal(t, 8, taskExecInfos[2].SubtaskConcurrency)
}
56 changes: 46 additions & 10 deletions pkg/disttask/framework/storage/task_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ import (
const (
defaultSubtaskKeepDays = 14

basicTaskColumns = `id, task_key, type, state, step, priority, concurrency, create_time`
basicTaskColumns = `t.id, t.task_key, t.type, t.state, t.step, t.priority, t.concurrency, t.create_time`
// TaskColumns is the columns for task.
// TODO: dispatcher_id will update to scheduler_id later
TaskColumns = basicTaskColumns + `, start_time, state_update_time, meta, dispatcher_id, error`
TaskColumns = basicTaskColumns + `, t.start_time, t.state_update_time, t.meta, t.dispatcher_id, t.error`
// InsertTaskColumns is the columns used in insert task.
InsertTaskColumns = `task_key, type, state, priority, concurrency, step, meta, create_time`
basicSubtaskColumns = `id, step, task_key, type, exec_id, state, concurrency, create_time, ordinal`
Expand Down Expand Up @@ -70,6 +70,17 @@ var (
ErrSubtaskNotFound = errors.New("subtask not found")
)

// TaskExecInfo is the execution information of a task, on some exec node.
type TaskExecInfo struct {
// Task is the embedded task that this execution information belongs to.
*proto.Task
// SubtaskConcurrency is the concurrency of subtask in current task step.
// As populated by GetTaskExecInfoByExecID, it is the maximum concurrency
// among the pending/running subtasks of the task's current step on the
// queried exec node.
// TODO: will be used when we support subtasks having smaller concurrency
// than the task, such as post-process of import-into. We might need to
// create one task executor for each step in this case, to alloc minimal
// resource.
SubtaskConcurrency int
}

// SessionExecutor defines the interface for executing SQLs in a session.
type SessionExecutor interface {
// WithNewSession executes the function with a new session.
Expand Down Expand Up @@ -221,7 +232,7 @@ func (mgr *TaskManager) CreateTaskWithSession(ctx context.Context, se sessionctx
// GetTopUnfinishedTasks implements the scheduler.TaskManager interface.
func (mgr *TaskManager) GetTopUnfinishedTasks(ctx context.Context) (task []*proto.Task, err error) {
rs, err := mgr.ExecuteSQLWithNewSession(ctx,
`select `+basicTaskColumns+` from mysql.tidb_global_task
`select `+basicTaskColumns+` from mysql.tidb_global_task t
where state in (%?, %?, %?, %?, %?, %?)
order by priority asc, create_time asc, id asc
limit %?`,
Expand All @@ -243,14 +254,39 @@ func (mgr *TaskManager) GetTopUnfinishedTasks(ctx context.Context) (task []*prot
return task, nil
}

// GetTaskExecInfoByExecID returns the exec info of every task in
// running/reverting/pausing state that has at least one pending or running
// subtask of its current step on the given execID. The result is ordered by
// priority, create_time and id, and each entry carries the maximum
// concurrency among those subtasks.
func (mgr *TaskManager) GetTaskExecInfoByExecID(ctx context.Context, execID string) ([]*TaskExecInfo, error) {
	// max(st.concurrency) is selected right after the 13 columns of
	// TaskColumns, so it is found at column index 13 of each row.
	const subtaskConcurrencyCol = 13

	rows, err := mgr.ExecuteSQLWithNewSession(ctx,
		`select `+TaskColumns+`, max(st.concurrency)
from mysql.tidb_global_task t join mysql.tidb_background_subtask st
on t.id = st.task_key and t.step = st.step
where t.state in (%?, %?, %?) and st.state in (%?, %?) and st.exec_id = %?
group by t.id
order by priority asc, create_time asc, id asc`,
		proto.TaskStateRunning, proto.TaskStateReverting, proto.TaskStatePausing,
		proto.SubtaskStatePending, proto.SubtaskStateRunning, execID)
	if err != nil {
		return nil, err
	}

	// one exec info per returned row, in the same order as the query result.
	infos := make([]*TaskExecInfo, len(rows))
	for i, row := range rows {
		infos[i] = &TaskExecInfo{
			Task:               Row2Task(row),
			SubtaskConcurrency: int(row.GetInt64(subtaskConcurrencyCol)),
		}
	}
	return infos, nil
}

// GetTasksInStates gets the tasks in the states(order by priority asc, create_time asc, id asc).
func (mgr *TaskManager) GetTasksInStates(ctx context.Context, states ...interface{}) (task []*proto.Task, err error) {
if len(states) == 0 {
return task, nil
}

rs, err := mgr.ExecuteSQLWithNewSession(ctx,
"select "+TaskColumns+" from mysql.tidb_global_task "+
"select "+TaskColumns+" from mysql.tidb_global_task t "+
"where state in ("+strings.Repeat("%?,", len(states)-1)+"%?)"+
" order by priority asc, create_time asc, id asc", states...)
if err != nil {
Expand All @@ -265,7 +301,7 @@ func (mgr *TaskManager) GetTasksInStates(ctx context.Context, states ...interfac

// GetTaskByID gets the task by the task ID.
func (mgr *TaskManager) GetTaskByID(ctx context.Context, taskID int64) (task *proto.Task, err error) {
rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task where id = %?", taskID)
rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task t where id = %?", taskID)
if err != nil {
return task, err
}
Expand All @@ -278,8 +314,8 @@ func (mgr *TaskManager) GetTaskByID(ctx context.Context, taskID int64) (task *pr

// GetTaskByIDWithHistory gets the task by the task ID from both tidb_global_task and tidb_global_task_history.
func (mgr *TaskManager) GetTaskByIDWithHistory(ctx context.Context, taskID int64) (task *proto.Task, err error) {
rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task where id = %? "+
"union select "+TaskColumns+" from mysql.tidb_global_task_history where id = %?", taskID, taskID)
rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task t where id = %? "+
"union select "+TaskColumns+" from mysql.tidb_global_task_history t where id = %?", taskID, taskID)
if err != nil {
return task, err
}
Expand All @@ -292,7 +328,7 @@ func (mgr *TaskManager) GetTaskByIDWithHistory(ctx context.Context, taskID int64

// GetTaskByKey gets the task by the task key.
func (mgr *TaskManager) GetTaskByKey(ctx context.Context, key string) (task *proto.Task, err error) {
rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task where task_key = %?", key)
rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task t where task_key = %?", key)
if err != nil {
return task, err
}
Expand All @@ -305,8 +341,8 @@ func (mgr *TaskManager) GetTaskByKey(ctx context.Context, key string) (task *pro

// GetTaskByKeyWithHistory gets the task by the task key from both tidb_global_task and tidb_global_task_history.
func (mgr *TaskManager) GetTaskByKeyWithHistory(ctx context.Context, key string) (task *proto.Task, err error) {
rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task where task_key = %?"+
"union select "+TaskColumns+" from mysql.tidb_global_task_history where task_key = %?", key, key)
rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task t where task_key = %?"+
"union select "+TaskColumns+" from mysql.tidb_global_task_history t where task_key = %?", key, key)
if err != nil {
return task, err
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/disttask/framework/taskexecutor/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@ import (
"context"

"github.com/pingcap/tidb/pkg/disttask/framework/proto"
"github.com/pingcap/tidb/pkg/disttask/framework/storage"
"github.com/pingcap/tidb/pkg/disttask/framework/taskexecutor/execute"
)

// TaskTable defines the interface to access the task table.
type TaskTable interface {
// GetTaskExecInfoByExecID gets all task exec infos by given execID, if there's
// no executable subtask on the execID for some task, it's not returned.
GetTaskExecInfoByExecID(ctx context.Context, execID string) ([]*storage.TaskExecInfo, error)
GetTasksInStates(ctx context.Context, states ...interface{}) (task []*proto.Task, err error)
GetTaskByID(ctx context.Context, taskID int64) (task *proto.Task, err error)
// GetSubtasksByExecIDAndStepAndStates gets all subtasks by given states and execID.
Expand Down
Loading

0 comments on commit 40f0d16

Please sign in to comment.