disttask: fix wrong plan when filterByRole (#48366)
close #48368
ywqzzy authored Nov 8, 2023
1 parent 57a1a9c commit d4618d4
Showing 13 changed files with 108 additions and 87 deletions.
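In short: before this change, the backfilling planner sized the non-partitioned plan by the instance count it fetched itself from GetEligibleInstances, while the dispatcher later filtered those instances by role before handing out subtasks, so the plan could be cut for nodes that never receive work. The dispatcher now resolves the eligible servers once, applies the role filter when the extension requests it, and passes the resulting list into OnNextSubtasksBatch. A trimmed sketch of the two signature changes driving the diffs below (signatures copied from pkg/disttask/framework/dispatcher/interface.go; the package name, TaskHandle stand-in, and import paths are scaffolding assumed for illustration):

package sketch

import (
	"context"

	"github.com/pingcap/tidb/pkg/disttask/framework/proto"
	"github.com/pingcap/tidb/pkg/domain/infosync"
)

// TaskHandle stands in for dispatcher.TaskHandle so the sketch is self-contained.
type TaskHandle interface{}

type Extension interface {
	// OnNextSubtasksBatch now receives the eligible (and, when requested,
	// role-filtered) servers chosen by the base dispatcher, so planners such as
	// generateNonPartitionPlan can size the plan by len(serverInfo).
	OnNextSubtasksBatch(ctx context.Context, h TaskHandle, task *proto.Task,
		serverInfo []*infosync.ServerInfo, step proto.Step) (subtaskMetas [][]byte, err error)

	// GetEligibleInstances gains a bool result telling the base dispatcher
	// whether it should still filter the returned instances by role.
	GetEligibleInstances(ctx context.Context, task *proto.Task) ([]*infosync.ServerInfo, bool, error)
}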
14 changes: 5 additions & 9 deletions pkg/ddl/backfilling_dispatcher.go
@@ -70,6 +70,7 @@ func (dsp *BackfillingDispatcherExt) OnNextSubtasksBatch(
ctx context.Context,
taskHandle dispatcher.TaskHandle,
gTask *proto.Task,
serverInfo []*infosync.ServerInfo,
nextStep proto.Step,
) (taskMeta [][]byte, err error) {
logger := logutil.BgLogger().With(
@@ -95,12 +96,7 @@ func (dsp *BackfillingDispatcherExt) OnNextSubtasksBatch(
if tblInfo.Partition != nil {
return generatePartitionPlan(tblInfo)
}
is, err := dsp.GetEligibleInstances(ctx, gTask)
if err != nil {
return nil, err
}
instanceCnt := len(is)
return generateNonPartitionPlan(dsp.d, tblInfo, job, dsp.GlobalSort, instanceCnt)
return generateNonPartitionPlan(dsp.d, tblInfo, job, dsp.GlobalSort, len(serverInfo))
case StepMergeSort:
res, err := generateMergePlan(taskHandle, gTask, logger)
if err != nil {
@@ -195,12 +191,12 @@ func (*BackfillingDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.Task
}

// GetEligibleInstances implements dispatcher.Extension interface.
func (*BackfillingDispatcherExt) GetEligibleInstances(ctx context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
func (*BackfillingDispatcherExt) GetEligibleInstances(ctx context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) {
serverInfos, err := dispatcher.GenerateSchedulerNodes(ctx)
if err != nil {
return nil, err
return nil, true, err
}
return serverInfos, nil
return serverInfos, true, nil
}

// IsRetryableErr implements dispatcher.Extension.IsRetryableErr interface.
22 changes: 13 additions & 9 deletions pkg/ddl/backfilling_dispatcher_test.go
@@ -70,7 +70,9 @@ func TestBackfillingDispatcherLocalMode(t *testing.T) {
// 1.1 OnNextSubtasksBatch
gTask.Step = dsp.GetNextStep(gTask)
require.Equal(t, ddl.StepReadIndex, gTask.Step)
metas, err := dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, gTask.Step)
serverInfos, _, err := dsp.GetEligibleInstances(context.Background(), gTask)
require.NoError(t, err)
metas, err := dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, serverInfos, gTask.Step)
require.NoError(t, err)
require.Equal(t, len(tblInfo.Partition.Definitions), len(metas))
for i, par := range tblInfo.Partition.Definitions {
@@ -83,7 +85,7 @@ func TestBackfillingDispatcherLocalMode(t *testing.T) {
gTask.State = proto.TaskStateRunning
gTask.Step = dsp.GetNextStep(gTask)
require.Equal(t, proto.StepDone, gTask.Step)
metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, gTask.Step)
metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, serverInfos, gTask.Step)
require.NoError(t, err)
require.Len(t, metas, 0)

@@ -100,7 +102,7 @@ func TestBackfillingDispatcherLocalMode(t *testing.T) {
// 2.1 empty table
tk.MustExec("create table t1(id int primary key, v int)")
gTask = createAddIndexGlobalTask(t, dom, "test", "t1", proto.Backfill, false)
metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, gTask.Step)
metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, serverInfos, gTask.Step)
require.NoError(t, err)
require.Equal(t, 0, len(metas))
// 2.2 non empty table.
@@ -112,15 +114,15 @@ func TestBackfillingDispatcherLocalMode(t *testing.T) {
gTask = createAddIndexGlobalTask(t, dom, "test", "t2", proto.Backfill, false)
// 2.2.1 stepInit
gTask.Step = dsp.GetNextStep(gTask)
metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, gTask.Step)
metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, serverInfos, gTask.Step)
require.NoError(t, err)
require.Equal(t, 1, len(metas))
require.Equal(t, ddl.StepReadIndex, gTask.Step)
// 2.2.2 StepReadIndex
gTask.State = proto.TaskStateRunning
gTask.Step = dsp.GetNextStep(gTask)
require.Equal(t, proto.StepDone, gTask.Step)
metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, gTask.Step)
metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, serverInfos, gTask.Step)
require.NoError(t, err)
require.Equal(t, 0, len(metas))
}
@@ -176,9 +178,11 @@ func TestBackfillingDispatcherGlobalSortMode(t *testing.T) {
taskID, err := mgr.AddNewGlobalTask(task.Key, proto.Backfill, 1, task.Meta)
require.NoError(t, err)
task.ID = taskID
serverInfos, _, err := dsp.GetEligibleInstances(context.Background(), task)
require.NoError(t, err)

// 1. to read-index stage
subtaskMetas, err := dsp.OnNextSubtasksBatch(ctx, dsp, task, dsp.GetNextStep(task))
subtaskMetas, err := dsp.OnNextSubtasksBatch(ctx, dsp, task, serverInfos, dsp.GetNextStep(task))
require.NoError(t, err)
require.Len(t, subtaskMetas, 1)
task.Step = ext.GetNextStep(task)
@@ -219,7 +223,7 @@ func TestBackfillingDispatcherGlobalSortMode(t *testing.T) {
t.Cleanup(func() {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/ddl/forceMergeSort"))
})
subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, dsp, task, ext.GetNextStep(task))
subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, dsp, task, serverInfos, ext.GetNextStep(task))
require.NoError(t, err)
require.Len(t, subtaskMetas, 1)
task.Step = ext.GetNextStep(task)
@@ -258,13 +262,13 @@ func TestBackfillingDispatcherGlobalSortMode(t *testing.T) {
t.Cleanup(func() {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/ddl/mockWriteIngest"))
})
subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, dsp, task, ext.GetNextStep(task))
subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, dsp, task, serverInfos, ext.GetNextStep(task))
require.NoError(t, err)
require.Len(t, subtaskMetas, 1)
task.Step = ext.GetNextStep(task)
require.Equal(t, ddl.StepWriteAndIngest, task.Step)
// 4. to done stage.
subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, dsp, task, ext.GetNextStep(task))
subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, dsp, task, serverInfos, ext.GetNextStep(task))
require.NoError(t, err)
require.Len(t, subtaskMetas, 0)
task.Step = ext.GetNextStep(task)
65 changes: 41 additions & 24 deletions pkg/disttask/framework/dispatcher/dispatcher.go
@@ -373,10 +373,17 @@ func (d *BaseDispatcher) replaceDeadNodesIfAny() error {
if err != nil {
return err
}
eligibleServerInfos, err := d.GetEligibleInstances(d.ctx, d.Task)

eligibleServerInfos, filter, err := d.GetEligibleInstances(d.ctx, d.Task)
if err != nil {
return err
}
if filter {
eligibleServerInfos, err = d.filterByRole(eligibleServerInfos)
if err != nil {
return err
}
}
newInfos := serverInfos[:0]
for _, m := range serverInfos {
found := false
@@ -545,7 +552,25 @@ func (d *BaseDispatcher) onNextStage() (err error) {

for {
// 3. generate a batch of subtasks.
metas, err := d.OnNextSubtasksBatch(d.ctx, d, d.Task, nextStep)
// select all available TiDB nodes for task.
serverNodes, filter, err := d.GetEligibleInstances(d.ctx, d.Task)
logutil.Logger(d.logCtx).Debug("eligible instances", zap.Int("num", len(serverNodes)))

if err != nil {
return err
}
if filter {
serverNodes, err = d.filterByRole(serverNodes)
if err != nil {
return err
}
}
logutil.Logger(d.logCtx).Info("eligible instances", zap.Int("num", len(serverNodes)))
if len(serverNodes) == 0 {
return errors.New("no available TiDB node to dispatch subtasks")
}

metas, err := d.OnNextSubtasksBatch(d.ctx, d, d.Task, serverNodes, nextStep)
if err != nil {
logutil.Logger(d.logCtx).Warn("generate part of subtasks failed", zap.Error(err))
return d.handlePlanErr(err)
@@ -556,7 +581,7 @@ func (d *BaseDispatcher) onNextStage() (err error) {
})

// 4. dispatch batch of subtasks to EligibleInstances.
err = d.dispatchSubTask(nextStep, metas)
err = d.dispatchSubTask(nextStep, metas, serverNodes)
if err != nil {
return err
}
@@ -572,27 +597,11 @@ func (d *BaseDispatcher) onNextStage() (err error) {
return nil
}

func (d *BaseDispatcher) dispatchSubTask(subtaskStep proto.Step, metas [][]byte) error {
func (d *BaseDispatcher) dispatchSubTask(
subtaskStep proto.Step,
metas [][]byte,
serverNodes []*infosync.ServerInfo) error {
logutil.Logger(d.logCtx).Info("dispatch subtasks", zap.Stringer("state", d.Task.State), zap.Int64("step", int64(d.Task.Step)), zap.Uint64("concurrency", d.Task.Concurrency), zap.Int("subtasks", len(metas)))

// select all available TiDB nodes for task.
serverNodes, err := d.GetEligibleInstances(d.ctx, d.Task)
logutil.Logger(d.logCtx).Debug("eligible instances", zap.Int("num", len(serverNodes)))

if err != nil {
return err
}
// 4. filter by role.
serverNodes, err = d.filterByRole(serverNodes)
if err != nil {
return err
}

logutil.Logger(d.logCtx).Info("eligible instances", zap.Int("num", len(serverNodes)))

if len(serverNodes) == 0 {
return errors.New("no available TiDB node to dispatch subtasks")
}
d.taskNodes = make([]string, len(serverNodes))
for i := range serverNodes {
d.taskNodes[i] = disttaskutil.GenerateExecID(serverNodes[i].IP, serverNodes[i].Port)
@@ -619,8 +628,14 @@ func (d *BaseDispatcher) handlePlanErr(err error) error {
return d.updateTask(proto.TaskStateFailed, nil, RetrySQLTimes)
}

// MockServerInfo exported for dispatcher_test.go
var MockServerInfo []*infosync.ServerInfo

// GenerateSchedulerNodes generates eligible TiDB nodes.
func GenerateSchedulerNodes(ctx context.Context) (serverNodes []*infosync.ServerInfo, err error) {
failpoint.Inject("mockSchedulerNodes", func() {
failpoint.Return(MockServerInfo, nil)
})
var serverInfos map[string]*infosync.ServerInfo
_, etcd := ctx.Value("etcd").(bool)
if intest.InTest && !etcd {
@@ -668,7 +683,9 @@ func (d *BaseDispatcher) filterByRole(infos []*infosync.ServerInfo) ([]*infosync

// GetAllSchedulerIDs gets all the scheduler IDs.
func (d *BaseDispatcher) GetAllSchedulerIDs(ctx context.Context, task *proto.Task) ([]string, error) {
serverInfos, err := d.GetEligibleInstances(ctx, task)
// We get all servers instead of eligible servers here
// because eligible servers may change during the task execution.
serverInfos, err := GenerateSchedulerNodes(ctx)
if err != nil {
return nil, err
}
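Putting the dispatcher.go hunks together, planning and dispatching now share one node list. The sketch below condenses the new ordering inside onNextStage; it is an illustration, not the verbatim implementation, and the method name planAndDispatchOnce is invented for the sketch (everything it calls appears in the diff above and lives in the dispatcher package):

// Conceptually inside package dispatcher, reusing its existing imports.
func (d *BaseDispatcher) planAndDispatchOnce(nextStep proto.Step) error {
	// Resolve eligible instances once per batch of subtasks.
	serverNodes, filter, err := d.GetEligibleInstances(d.ctx, d.Task)
	if err != nil {
		return err
	}
	// Filter by role only when the extension asks for it.
	if filter {
		if serverNodes, err = d.filterByRole(serverNodes); err != nil {
			return err
		}
	}
	if len(serverNodes) == 0 {
		return errors.New("no available TiDB node to dispatch subtasks")
	}
	// The same node list is handed to the planner...
	metas, err := d.OnNextSubtasksBatch(d.ctx, d, d.Task, serverNodes, nextStep)
	if err != nil {
		return d.handlePlanErr(err)
	}
	// ...and to dispatchSubTask, so both agree on which nodes run subtasks.
	return d.dispatchSubTask(nextStep, metas, serverNodes)
}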
18 changes: 10 additions & 8 deletions pkg/disttask/framework/dispatcher/dispatcher_test.go
@@ -51,7 +51,7 @@ type testDispatcherExt struct{}
func (*testDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}

func (*testDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ proto.Step) (metas [][]byte, err error) {
func (*testDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ []*infosync.ServerInfo, _ proto.Step) (metas [][]byte, err error) {
return nil, nil
}

@@ -61,8 +61,8 @@ func (*testDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.TaskHandle,

var mockedAllServerInfos = []*infosync.ServerInfo{}

func (*testDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
return mockedAllServerInfos, nil
func (*testDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) {
return mockedAllServerInfos, true, nil
}

func (*testDispatcherExt) IsRetryableErr(error) bool {
@@ -78,7 +78,7 @@ type numberExampleDispatcherExt struct{}
func (*numberExampleDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}

func (n *numberExampleDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task, _ proto.Step) (metas [][]byte, err error) {
func (n *numberExampleDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task, _ []*infosync.ServerInfo, _ proto.Step) (metas [][]byte, err error) {
switch task.Step {
case proto.StepInit:
for i := 0; i < subtaskCnt; i++ {
@@ -99,8 +99,9 @@ func (n *numberExampleDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.
return nil, nil
}

func (*numberExampleDispatcherExt) GetEligibleInstances(ctx context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
return dispatcher.GenerateSchedulerNodes(ctx)
func (*numberExampleDispatcherExt) GetEligibleInstances(ctx context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) {
serverInfo, err := dispatcher.GenerateSchedulerNodes(ctx)
return serverInfo, true, err
}

func (*numberExampleDispatcherExt) IsRetryableErr(error) bool {
@@ -157,7 +158,7 @@ func TestGetInstance(t *testing.T) {
return gtk.Session(), nil
}, 1, 1, time.Second)
defer pool.Close()

require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/disttask/framework/dispatcher/mockSchedulerNodes", "return()"))
dspManager, mgr := MockDispatcherManager(t, pool)
// test no server
task := &proto.Task{ID: 1, Type: proto.TaskTypeExample}
@@ -173,7 +174,7 @@ func TestGetInstance(t *testing.T) {
uuids := []string{"ddl_id_1", "ddl_id_2"}
serverIDs := []string{"10.123.124.10:32457", "[ABCD:EF01:2345:6789:ABCD:EF01:2345:6789]:65535"}

mockedAllServerInfos = []*infosync.ServerInfo{
dispatcher.MockServerInfo = []*infosync.ServerInfo{
{
ID: uuids[0],
IP: "10.123.124.10",
@@ -214,6 +215,7 @@ func TestGetInstance(t *testing.T) {
require.NoError(t, err)
require.Len(t, instanceIDs, len(serverIDs))
require.ElementsMatch(t, instanceIDs, serverIDs)
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/disttask/framework/dispatcher/mockSchedulerNodes"))
}

func TestTaskFailInManager(t *testing.T) {
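The test now stubs node discovery through a failpoint instead of through the extension: when mockSchedulerNodes is enabled, GenerateSchedulerNodes returns the exported dispatcher.MockServerInfo slice. A minimal usage sketch, condensed from the TestGetInstance hunks above (the two ServerInfo entries mirror the test fixtures; other field values would work the same way):

// Short-circuit GenerateSchedulerNodes to return MockServerInfo.
require.NoError(t, failpoint.Enable(
	"github.com/pingcap/tidb/pkg/disttask/framework/dispatcher/mockSchedulerNodes", "return()"))
defer func() {
	require.NoError(t, failpoint.Disable(
		"github.com/pingcap/tidb/pkg/disttask/framework/dispatcher/mockSchedulerNodes"))
}()

// The test controls the "cluster" by assigning the exported variable directly.
dispatcher.MockServerInfo = []*infosync.ServerInfo{
	{ID: "ddl_id_1", IP: "10.123.124.10", Port: 32457},
	{ID: "ddl_id_2", IP: "ABCD:EF01:2345:6789:ABCD:EF01:2345:6789", Port: 65535},
}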
5 changes: 3 additions & 2 deletions pkg/disttask/framework/dispatcher/interface.go
@@ -64,7 +64,7 @@ type Extension interface {
// 1. task is pending and entering its first step.
// 2. subtasks dispatched has all finished with no error.
// when next step is StepDone, it should return nil, nil.
OnNextSubtasksBatch(ctx context.Context, h TaskHandle, task *proto.Task, step proto.Step) (subtaskMetas [][]byte, err error)
OnNextSubtasksBatch(ctx context.Context, h TaskHandle, task *proto.Task, serverInfo []*infosync.ServerInfo, step proto.Step) (subtaskMetas [][]byte, err error)

// OnErrStage is called when:
// 1. subtask is finished with error.
@@ -73,7 +73,8 @@ type Extension interface {

// GetEligibleInstances is used to get the eligible instances for the task.
// on certain condition we may want to use some instances to do the task, such as instances with more disk.
GetEligibleInstances(ctx context.Context, task *proto.Task) ([]*infosync.ServerInfo, error)
// The bool return value indicates whether to filter instances by role.
GetEligibleInstances(ctx context.Context, task *proto.Task) ([]*infosync.ServerInfo, bool, error)

// IsRetryableErr is used to check whether the error occurred in dispatcher is retryable.
IsRetryableErr(err error) bool
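For extension authors, adopting the new contract mostly means returning the extra bool. A hedged example, mirroring what BackfillingDispatcherExt and the test extensions in this commit do (the type name exampleDispatcherExt is illustrative only):

type exampleDispatcherExt struct{}

// Returning true delegates role filtering to the base dispatcher's filterByRole;
// returning false tells it to use the returned list as-is.
func (*exampleDispatcherExt) GetEligibleInstances(ctx context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) {
	serverInfos, err := dispatcher.GenerateSchedulerNodes(ctx)
	if err != nil {
		return nil, true, err
	}
	return serverInfos, true, nil
}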
4 changes: 2 additions & 2 deletions pkg/disttask/framework/framework_dynamic_dispatch_test.go
@@ -37,7 +37,7 @@ var _ dispatcher.Extension = (*testDynamicDispatcherExt)(nil)

func (*testDynamicDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {}

func (dsp *testDynamicDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task, _ proto.Step) (metas [][]byte, err error) {
func (dsp *testDynamicDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task, _ []*infosync.ServerInfo, _ proto.Step) (metas [][]byte, err error) {
// step1
if gTask.Step == proto.StepInit {
dsp.cnt++
@@ -72,7 +72,7 @@ func (dsp *testDynamicDispatcherExt) GetNextStep(task *proto.Task) proto.Step {
}
}

func (*testDynamicDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
func (*testDynamicDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) {
return generateSchedulerNodes4Test()
}

8 changes: 4 additions & 4 deletions pkg/disttask/framework/framework_err_handling_test.go
@@ -40,7 +40,7 @@ var (
func (*planErrDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}

func (p *planErrDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task, _ proto.Step) (metas [][]byte, err error) {
func (p *planErrDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task, _ []*infosync.ServerInfo, _ proto.Step) (metas [][]byte, err error) {
if gTask.Step == proto.StepInit {
if p.callTime == 0 {
p.callTime++
@@ -70,7 +70,7 @@ func (p *planErrDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.TaskHa
return []byte("planErrTask"), nil
}

func (*planErrDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
func (*planErrDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) {
return generateSchedulerNodes4Test()
}

@@ -96,15 +96,15 @@ type planNotRetryableErrDispatcherExt struct {
func (*planNotRetryableErrDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}

func (p *planNotRetryableErrDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ proto.Step) (metas [][]byte, err error) {
func (p *planNotRetryableErrDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ []*infosync.ServerInfo, _ proto.Step) (metas [][]byte, err error) {
return nil, errors.New("not retryable err")
}

func (*planNotRetryableErrDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ []error) (meta []byte, err error) {
return nil, errors.New("not retryable err")
}

func (*planNotRetryableErrDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
func (*planNotRetryableErrDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) {
return generateSchedulerNodes4Test()
}

4 changes: 2 additions & 2 deletions pkg/disttask/framework/framework_ha_test.go
@@ -37,7 +37,7 @@ var _ dispatcher.Extension = (*haTestDispatcherExt)(nil)
func (*haTestDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}

func (dsp *haTestDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task, _ proto.Step) (metas [][]byte, err error) {
func (dsp *haTestDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task, _ []*infosync.ServerInfo, _ proto.Step) (metas [][]byte, err error) {
if gTask.Step == proto.StepInit {
dsp.cnt = 10
return [][]byte{
@@ -70,7 +70,7 @@ func (*haTestDispatcherExt) OnErrStage(ctx context.Context, h dispatcher.TaskHan
return nil, nil
}

func (*haTestDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
func (*haTestDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) {
return generateSchedulerNodes4Test()
}

(Diffs for the remaining 5 changed files are not shown here.)
