disttask: refactor task step and batch subtask dispatching #46957

Merged · 8 commits · Sep 14, 2023
Changes from 5 commits
20 changes: 11 additions & 9 deletions ddl/backfilling_dispatcher.go
@@ -117,14 +117,16 @@ func (h *backfillingDispatcherExt) OnNextSubtasksBatch(ctx context.Context,
 	}
 }
 
-// StageFinished check if current stage finished.
-func (*backfillingDispatcherExt) StageFinished(_ *proto.Task) bool {
-	return true
-}
-
-// Finished check if current task finished.
-func (*backfillingDispatcherExt) Finished(task *proto.Task) bool {
-	return task.Step == proto.StepOne
+func (*backfillingDispatcherExt) GetNextStep(task *proto.Task) int64 {
+	switch task.Step {
+	case proto.StepInit:
+		return proto.StepOne
+	case proto.StepOne:
+		return proto.StepTwo
+	default:
+		// current step must be proto.StepTwo
+		return proto.StepDone
+	}
 }
 
 // OnErrStage generate error handling stage's plan.
@@ -372,7 +374,7 @@ func getSummaryFromLastStep(
 	taskHandle dispatcher.TaskHandle,
 	gTaskID int64,
 ) (min, max kv.Key, totalKVSize uint64, dataFiles, statFiles []string, err error) {
-	subTaskMetas, err := taskHandle.GetPreviousSubtaskMetas(gTaskID, proto.StepInit)
+	subTaskMetas, err := taskHandle.GetPreviousSubtaskMetas(gTaskID, proto.StepOne)
 	if err != nil {
 		return nil, nil, 0, nil, nil, errors.Trace(err)
 	}
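For orientation, here is a minimal standalone sketch of the step sequence this diff gives the backfill dispatcher. The constant values are illustrative stand-ins (the real definitions live in disttask/framework/proto); only their ordering matters here.

```go
package main

import "fmt"

// Stand-in step constants; the real ones live in disttask/framework/proto,
// only the ordering is relied on in this sketch.
const (
	stepInit int64 = -1
	stepOne  int64 = 1 // read-index stage (keyed by StepInit before this PR)
	stepTwo  int64 = 2 // merge-sort/ingest stage (keyed by StepOne before this PR)
	stepDone int64 = -2
)

// nextStep mirrors backfillingDispatcherExt.GetNextStep above.
func nextStep(step int64) int64 {
	switch step {
	case stepInit:
		return stepOne
	case stepOne:
		return stepTwo
	default: // the current step can only be stepTwo here
		return stepDone
	}
}

func main() {
	// Walk a task through its whole lifecycle, as the dispatcher now does.
	for s := stepInit; s != stepDone; s = nextStep(s) {
		fmt.Printf("step %d -> step %d\n", s, nextStep(s))
	}
}
```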
21 changes: 19 additions & 2 deletions ddl/backfilling_dispatcher_test.go
@@ -61,10 +61,19 @@ func TestBackfillingDispatcher(t *testing.T) {
 
 	// 1.2 test partition table OnNextSubtasksBatch after StepInit finished.
 	gTask.State = proto.TaskStateRunning
-	gTask.Step++
+	gTask.Step = dsp.GetNextStep(gTask)
+	require.Equal(t, proto.StepOne, gTask.Step)
+	// empty stepTwo
 	metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask)
 	require.NoError(t, err)
 	require.Equal(t, 0, len(metas))
+	gTask.Step = dsp.GetNextStep(gTask)
+	require.Equal(t, proto.StepTwo, gTask.Step)
+	metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask)
+	require.NoError(t, err)
+	require.Equal(t, 0, len(metas))
+	gTask.Step = dsp.GetNextStep(gTask)
+	require.Equal(t, proto.StepDone, gTask.Step)
 
 	// 1.3 test partition table OnErrStage.
 	errMeta, err := dsp.OnErrStage(context.Background(), nil, gTask, []error{errors.New("mockErr")})
@@ -93,12 +102,20 @@ func TestBackfillingDispatcher(t *testing.T) {
 	metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask)
 	require.NoError(t, err)
 	require.Equal(t, 1, len(metas))
+	gTask.Step = dsp.GetNextStep(gTask)
+	require.Equal(t, proto.StepOne, gTask.Step)
 	// 2.2.2 stepOne
-	gTask.Step++
 	gTask.State = proto.TaskStateRunning
 	metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask)
 	require.NoError(t, err)
 	require.Equal(t, 1, len(metas))
+	gTask.Step = dsp.GetNextStep(gTask)
+	require.Equal(t, proto.StepTwo, gTask.Step)
+	metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask)
+	require.NoError(t, err)
+	require.Equal(t, 0, len(metas))
+	gTask.Step = dsp.GetNextStep(gTask)
+	require.Equal(t, proto.StepDone, gTask.Step)
 }
 
 func createAddIndexGlobalTask(t *testing.T, dom *domain.Domain, dbName, tblName string, taskType string) *proto.Task {
2 changes: 1 addition & 1 deletion ddl/index.go
@@ -1959,7 +1959,7 @@ func (w *worker) updateJobRowCount(taskKey string, jobID int64) {
 		logutil.BgLogger().Warn("cannot get global task", zap.String("category", "ddl"), zap.String("task_key", taskKey), zap.Error(err))
 		return
 	}
-	rowCount, err := taskMgr.GetSubtaskRowCount(gTask.ID, proto.StepInit)
+	rowCount, err := taskMgr.GetSubtaskRowCount(gTask.ID, proto.StepOne)
 	if err != nil {
 		logutil.BgLogger().Warn("cannot get subtask row count", zap.String("category", "ddl"), zap.String("task_key", taskKey), zap.Error(err))
 		return
6 changes: 3 additions & 3 deletions ddl/stage_scheduler.go
@@ -79,13 +79,13 @@ func NewBackfillSchedulerHandle(ctx context.Context, taskMeta []byte, d *ddl,
 	}
 
 	switch stage {
-	case proto.StepInit:
+	case proto.StepOne:
 		jc := d.jobContext(jobMeta.ID, jobMeta.ReorgMeta)
 		d.setDDLLabelForTopSQL(jobMeta.ID, jobMeta.Query)
 		d.setDDLSourceForDiagnosis(jobMeta.ID, jobMeta.Type)
 		return newReadIndexStage(
 			d, &bgm.Job, indexInfo, tbl.(table.PhysicalTable), jc, bc, summary, bgm.CloudStorageURI), nil
-	case proto.StepOne:
+	case proto.StepTwo:
 		if len(bgm.CloudStorageURI) > 0 {
 			return newMergeSortStage(jobMeta.ID, indexInfo, tbl.(table.PhysicalTable), bc, bgm.CloudStorageURI)
 		}
@@ -114,7 +114,7 @@ func newBackfillDistScheduler(ctx context.Context, id string, taskID int64, task
 
 func (s *backfillDistScheduler) GetSubtaskExecutor(ctx context.Context, task *proto.Task, summary *execute.Summary) (execute.SubtaskExecutor, error) {
 	switch task.Step {
-	case proto.StepInit, proto.StepOne:
+	case proto.StepOne, proto.StepTwo:
 		return NewBackfillSchedulerHandle(ctx, task.Meta, s.d, task.Step, summary)
 	default:
 		return nil, errors.Errorf("unknown backfill step %d for task %d", task.Step, task.ID)
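As a quick reference for the renumbering above, a small hypothetical snippet pinning down which backfill stage each step now selects (stage names from the diff; the step values are illustrative):

```go
package main

import "fmt"

// Illustrative step values; only the mapping matters. Before this PR the
// read-index stage was keyed by StepInit and merge-sort by StepOne.
const (
	stepOne int64 = 1 // read index
	stepTwo int64 = 2 // merge sort (taken only when a cloud storage URI is set)
)

func stageName(step int64) string {
	switch step {
	case stepOne:
		return "read-index"
	case stepTwo:
		return "merge-sort/ingest"
	default:
		return "unknown backfill step"
	}
}

func main() {
	for _, s := range []int64{stepOne, stepTwo} {
		fmt.Printf("step %d -> %s\n", s, stageName(s))
	}
}
```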
76 changes: 46 additions & 30 deletions disttask/framework/dispatcher/dispatcher.go
@@ -241,18 +241,6 @@ func (d *BaseDispatcher) onRunning() error {
 	}
 
 	if cnt == 0 {
-		logutil.Logger(d.logCtx).Info("previous subtasks finished, generate dist plan", zap.Int64("stage", d.Task.Step))
-		// When all subtasks dispatched and processed, mark task as succeed.
-		if d.Finished(d.Task) {
-			d.Task.StateUpdateTime = time.Now().UTC()
-			logutil.Logger(d.logCtx).Info("all subtasks dispatched and processed, finish the task")
-			err := d.UpdateTask(proto.TaskStateSucceed, nil, RetrySQLTimes)
-			if err != nil {
-				logutil.Logger(d.logCtx).Warn("update task failed", zap.Error(err))
-				return err
-			}
-			return nil
-		}
 		return d.onNextStage()
 	}
 	// Check if any node are down.
@@ -404,17 +392,29 @@ func (d *BaseDispatcher) dispatchSubTask4Revert(meta []byte) error {
 
 	subTasks := make([]*proto.Subtask, 0, len(instanceIDs))
 	for _, id := range instanceIDs {
-		subTasks = append(subTasks, proto.NewSubtask(d.Task.ID, d.Task.Type, id, meta))
+		// reverting subtasks belong to the same step as the current active step.
+		subTasks = append(subTasks, proto.NewSubtask(d.Task.Step, d.Task.ID, d.Task.Type, id, meta))
 	}
 	return d.UpdateTask(proto.TaskStateReverting, subTasks, RetrySQLTimes)
 }
 
-func (d *BaseDispatcher) onNextStage() error {
+func (*BaseDispatcher) nextStepSubtaskDispatched(*proto.Task) bool {
+	// TODO: implement this when we support dispatching subtasks in batches,
+	// since subtask metas might be too large to save in one transaction.
+	return true
+}
+
+func (d *BaseDispatcher) onNextStage() (err error) {
 	/// dynamic dispatch subtasks.
 	failpoint.Inject("mockDynamicDispatchErr", func() {
 		failpoint.Return(errors.New("mockDynamicDispatchErr"))
 	})
 
+	nextStep := d.GetNextStep(d.Task)
+	logutil.Logger(d.logCtx).Info("onNextStage",
+		zap.Int64("current-step", d.Task.Step),
+		zap.Int64("next-step", nextStep))
+
 	// 1. Adjust the global task's concurrency.
 	if d.Task.State == proto.TaskStatePending {
 		if d.Task.Concurrency == 0 {
@@ -423,20 +423,36 @@ func (d *BaseDispatcher) onNextStage() error {
 		if d.Task.Concurrency > MaxSubtaskConcurrency {
 			d.Task.Concurrency = MaxSubtaskConcurrency
 		}
-		d.Task.StateUpdateTime = time.Now().UTC()
-		if err := d.UpdateTask(proto.TaskStateRunning, nil, RetrySQLTimes); err != nil {
-			return err
-		}
-	} else if d.StageFinished(d.Task) {
-		// 2. when previous stage finished, update to next stage.
-		d.Task.Step++
-		logutil.Logger(d.logCtx).Info("previous stage finished, run into next stage", zap.Int64("from", d.Task.Step-1), zap.Int64("to", d.Task.Step))
-		d.Task.StateUpdateTime = time.Now().UTC()
-		err := d.UpdateTask(proto.TaskStateRunning, nil, RetrySQLTimes)
 	}
+	defer func() {
 		if err != nil {
-			return err
+			return
 		}
-	}
+		// invariant: task.Step always means the most recent step whose
+		// subtasks have all been saved to the system table.
+		//
+		// when all subtasks of task.Step are finished, we call OnNextSubtasksBatch
+		// to generate subtasks of the next step. after all subtasks of the next
+		// step are saved to the system table, we update task.Step to the next
+		// step, so the invariant holds.
+		// see nextStepSubtaskDispatched for why we don't update the task and its
+		// subtasks in a single transaction.
+		if d.nextStepSubtaskDispatched(d.Task) {
+			currStep := d.Task.Step
+			d.Task.Step = nextStep
+			// when all subtasks are dispatched and processed, mark the task as succeeded.
+			taskState := proto.TaskStateRunning
+			if d.Task.Step == proto.StepDone {
+				taskState = proto.TaskStateSucceed
+				logutil.Logger(d.logCtx).Info("all subtasks dispatched and processed, finish the task")
+			} else {
+				logutil.Logger(d.logCtx).Info("move to next stage",
+					zap.Int64("from", currStep), zap.Int64("to", d.Task.Step))
+			}
+			d.Task.StateUpdateTime = time.Now().UTC()
+			err = d.UpdateTask(taskState, nil, RetrySQLTimes)
+		}
+	}()
+
 	for {
 		// 3. generate a batch of subtasks.
@@ -451,12 +467,12 @@
 		})
 
 		// 4. dispatch batch of subtasks to EligibleInstances.
-		err = d.dispatchSubTask(metas)
+		err = d.dispatchSubTask(nextStep, metas)
 		if err != nil {
 			return err
 		}
 
-		if d.StageFinished(d.Task) {
+		if d.nextStepSubtaskDispatched(d.Task) {
 			break
 		}
 
@@ -467,7 +483,7 @@
 	return nil
 }
 
-func (d *BaseDispatcher) dispatchSubTask(metas [][]byte) error {
+func (d *BaseDispatcher) dispatchSubTask(subtaskStep int64, metas [][]byte) error {
 	logutil.Logger(d.logCtx).Info("dispatch subtasks", zap.String("state", d.Task.State), zap.Int64("step", d.Task.Step), zap.Uint64("concurrency", d.Task.Concurrency), zap.Int("subtasks", len(metas)))
 
 	// select all available TiDB nodes for task.
@@ -499,7 +515,7 @@ func (d *BaseDispatcher) dispatchSubTask(metas [][]byte) error {
 		pos := i % len(serverNodes)
 		instanceID := disttaskutil.GenerateExecID(serverNodes[pos].IP, serverNodes[pos].Port)
 		logutil.Logger(d.logCtx).Debug("create subtasks", zap.String("instanceID", instanceID))
-		subTasks = append(subTasks, proto.NewSubtask(d.Task.ID, d.Task.Type, instanceID, meta))
+		subTasks = append(subTasks, proto.NewSubtask(subtaskStep, d.Task.ID, d.Task.Type, instanceID, meta))
 	}
 	return d.addSubtasks(subTasks)
 }
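To make the new control flow easier to follow outside the diff, below is a compressed, hypothetical sketch of what onNextStage now guarantees: subtasks are persisted under the next step first, and task.Step is advanced only afterwards. The types and helpers are stand-ins, not the real framework API.

```go
package main

import "fmt"

// Illustrative step values, as in the earlier sketch.
const (
	stepInit int64 = -1
	stepOne  int64 = 1
	stepDone int64 = -2
)

type task struct{ step int64 }

// onNextStage condenses BaseDispatcher.onNextStage from this diff. The
// invariant: t.step always names the most recent step whose subtasks have
// all been saved, so a dispatcher crash never strands a half-persisted step.
func onNextStage(t *task, getNextStep func(*task) int64, nextBatch func(*task) [][]byte) {
	next := getNextStep(t)
	for {
		metas := nextBatch(t)
		saveSubtasks(next, metas) // subtasks are tagged with the *next* step
		if nextStepSubtaskDispatched(t) {
			break // every batch for `next` is in the system table
		}
	}
	t.step = next // only now is it safe to advance the step
	if t.step == stepDone {
		fmt.Println("all subtasks dispatched and processed, task succeeds")
	}
}

// Stand-ins: the real code writes to a system table and, for now, always
// dispatches in a single batch (multi-batch dispatch is a TODO in the PR).
func saveSubtasks(step int64, metas [][]byte) {}
func nextStepSubtaskDispatched(*task) bool    { return true }

func main() {
	t := &task{step: stepInit}
	getNext := func(t *task) int64 {
		if t.step == stepInit {
			return stepOne
		}
		return stepDone
	}
	batch := func(t *task) [][]byte {
		if t.step == stepInit {
			return [][]byte{[]byte("meta")}
		}
		return nil // next step is StepDone: no subtasks, per the Extension contract
	}
	for t.step != stepDone {
		onNextStage(t, getNext, batch)
	}
}
```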
21 changes: 9 additions & 12 deletions disttask/framework/dispatcher/dispatcher_test.go
@@ -66,12 +66,8 @@ func (*testDispatcherExt) IsRetryableErr(error) bool {
 	return true
 }
 
-func (dsp *testDispatcherExt) StageFinished(task *proto.Task) bool {
-	return true
-}
-
-func (dsp *testDispatcherExt) Finished(task *proto.Task) bool {
-	return false
+func (*testDispatcherExt) GetNextStep(*proto.Task) int64 {
+	return proto.StepDone
 }
 
 type numberExampleDispatcherExt struct{}
@@ -108,12 +104,13 @@ func (*numberExampleDispatcherExt) IsRetryableErr(error) bool {
 	return true
 }
 
-func (*numberExampleDispatcherExt) StageFinished(task *proto.Task) bool {
-	return true
-}
-
-func (*numberExampleDispatcherExt) Finished(task *proto.Task) bool {
-	return task.Step == proto.StepTwo
+func (*numberExampleDispatcherExt) GetNextStep(task *proto.Task) int64 {
+	switch task.Step {
+	case proto.StepInit:
+		return proto.StepOne
+	default:
+		return proto.StepDone
+	}
 }
 
 func MockDispatcherManager(t *testing.T, pool *pools.ResourcePool) (*dispatcher.Manager, *storage.TaskManager) {
13 changes: 5 additions & 8 deletions disttask/framework/dispatcher/interface.go
@@ -38,6 +38,7 @@ type Extension interface {
 	// it's called when:
 	// 1. task is pending and entering it's first step.
 	// 2. subtasks dispatched has all finished with no error.
+	// when the next step is StepDone, it should return nil, nil.
 	OnNextSubtasksBatch(ctx context.Context, h TaskHandle, task *proto.Task) (subtaskMetas [][]byte, err error)
 
 	// OnErrStage is called when:
@@ -52,14 +53,10 @@ type Extension interface {
 	// IsRetryableErr is used to check whether the error occurred in dispatcher is retryable.
 	IsRetryableErr(err error) bool
 
-	// StageFinished is used to check if all subtasks in current stage are dispatched and processed.
-	// StageFinished is called before generating batch of subtasks.
-	StageFinished(task *proto.Task) bool
-
-	// Finished is used to check if all subtasks for the task are dispatched and processed.
-	// Finished is called before generating batch of subtasks.
-	// Once Finished return true, mark the task as succeed.
-	Finished(task *proto.Task) bool
+	// GetNextStep is used to get the next step for the task.
+	// if the task runs successfully, it should go from StepInit through the
+	// business steps, and then to StepDone; the dispatcher then marks it finished.
+	GetNextStep(task *proto.Task) int64
 }
 
 // FactoryFn is used to create a dispatcher.
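To see the whole contract in one place, here is a hedged sketch of a two-step Extension written against this interface. The method set matches what the diff shows for the test dispatchers; the node returned by GetEligibleInstances is a made-up placeholder, and a real extension would consult the cluster instead.

```go
package main

import (
	"context"
	"fmt"

	"github.com/pingcap/tidb/disttask/framework/dispatcher"
	"github.com/pingcap/tidb/disttask/framework/proto"
	"github.com/pingcap/tidb/domain/infosync"
)

// twoStepExt is a minimal two-step Extension sketch.
type twoStepExt struct{}

var _ dispatcher.Extension = (*twoStepExt)(nil)

func (*twoStepExt) OnTick(context.Context, *proto.Task) {}

func (*twoStepExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task) ([][]byte, error) {
	switch task.Step {
	case proto.StepInit, proto.StepOne:
		return [][]byte{[]byte("meta")}, nil
	default:
		// next step is StepDone: return nil, nil, per the doc comment above.
		return nil, nil
	}
}

func (*twoStepExt) OnErrStage(context.Context, dispatcher.TaskHandle, *proto.Task, []error) ([]byte, error) {
	return nil, nil
}

func (*twoStepExt) GetEligibleInstances(context.Context, *proto.Task) ([]*infosync.ServerInfo, error) {
	// Placeholder node for illustration only.
	return []*infosync.ServerInfo{{ID: "tidb-0", IP: "127.0.0.1", Port: 4000}}, nil
}

func (*twoStepExt) IsRetryableErr(error) bool { return true }

func (*twoStepExt) GetNextStep(task *proto.Task) int64 {
	switch task.Step {
	case proto.StepInit:
		return proto.StepOne
	case proto.StepOne:
		return proto.StepTwo
	default:
		return proto.StepDone
	}
}

func main() {
	ext := &twoStepExt{}
	t := &proto.Task{Step: proto.StepInit}
	for t.Step != proto.StepDone {
		t.Step = ext.GetNextStep(t)
		fmt.Println("step ->", t.Step)
	}
}
```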
23 changes: 10 additions & 13 deletions disttask/framework/framework_dynamic_dispatch_test.go
@@ -39,7 +39,7 @@ func (*testDynamicDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {}
 
 func (dsp *testDynamicDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) {
 	// step1
-	if gTask.Step == proto.StepInit && dsp.cnt < 3 {
+	if gTask.Step == proto.StepInit {
 		dsp.cnt++
 		return [][]byte{
 			[]byte(fmt.Sprintf("task%d", dsp.cnt)),
@@ -48,7 +48,7 @@ func (dsp *testDynamicDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ di
 	}
 
 	// step2
-	if gTask.Step == proto.StepOne && dsp.cnt < 4 {
+	if gTask.Step == proto.StepOne {
 		dsp.cnt++
 		return [][]byte{
 			[]byte(fmt.Sprintf("task%d", dsp.cnt)),
@@ -61,18 +61,15 @@ func (*testDynamicDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.Task
 	return nil, nil
 }
 
-func (dsp *testDynamicDispatcherExt) StageFinished(task *proto.Task) bool {
-	if task.Step == proto.StepInit && dsp.cnt >= 3 {
-		return true
+func (dsp *testDynamicDispatcherExt) GetNextStep(task *proto.Task) int64 {
+	switch task.Step {
+	case proto.StepInit:
+		return proto.StepOne
+	case proto.StepOne:
+		return proto.StepTwo
+	default:
+		return proto.StepDone
 	}
-	if task.Step == proto.StepOne && dsp.cnt >= 4 {
-		return true
-	}
-	return false
 }
-
-func (dsp *testDynamicDispatcherExt) Finished(task *proto.Task) bool {
-	return task.Step == proto.StepOne && dsp.cnt >= 4
-}
 
 func (*testDynamicDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {