disttask: fix wrong plan when filterByRole #48366

Merged
merged 5 commits on Nov 8, 2023
14 changes: 5 additions & 9 deletions pkg/ddl/backfilling_dispatcher.go
@@ -70,6 +70,7 @@ func (dsp *BackfillingDispatcherExt) OnNextSubtasksBatch(
ctx context.Context,
taskHandle dispatcher.TaskHandle,
gTask *proto.Task,
serverInfo []*infosync.ServerInfo,
nextStep proto.Step,
) (taskMeta [][]byte, err error) {
logger := logutil.BgLogger().With(
@@ -95,12 +96,7 @@ func (dsp *BackfillingDispatcherExt) OnNextSubtasksBatch(
if tblInfo.Partition != nil {
return generatePartitionPlan(tblInfo)
}
is, err := dsp.GetEligibleInstances(ctx, gTask)
if err != nil {
return nil, err
}
instanceCnt := len(is)
return generateNonPartitionPlan(dsp.d, tblInfo, job, dsp.GlobalSort, instanceCnt)
return generateNonPartitionPlan(dsp.d, tblInfo, job, dsp.GlobalSort, len(serverInfo))
case StepMergeSort:
res, err := generateMergePlan(taskHandle, gTask, logger)
if err != nil {
@@ -195,12 +191,12 @@ func (*BackfillingDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.Task
}

// GetEligibleInstances implements dispatcher.Extension interface.
func (*BackfillingDispatcherExt) GetEligibleInstances(ctx context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
func (*BackfillingDispatcherExt) GetEligibleInstances(ctx context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) {
serverInfos, err := dispatcher.GenerateSchedulerNodes(ctx)
if err != nil {
return nil, err
return nil, true, err
}
return serverInfos, nil
return serverInfos, true, nil
}

// IsRetryableErr implements dispatcher.Extension.IsRetryableErr interface.
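Note on the hunk above: OnNextSubtasksBatch no longer looks up eligible instances itself. The dispatcher now passes in the serverInfo slice (already role-filtered when the extension asks for it), and the non-partition plan is sized with len(serverInfo). A minimal sketch of that sizing follows; splitIntoSubtasks is an invented helper, not the real generateNonPartitionPlan, and the infosync import path is assumed to match the rest of this PR.

package sketch

import (
	"github.com/pingcap/tidb/pkg/domain/infosync"
)

// splitIntoSubtasks splits rowCount rows into one contiguous [start, end)
// range per eligible node, mirroring how the non-partition plan now sizes
// itself from the serverInfo argument instead of calling
// GetEligibleInstances a second time.
func splitIntoSubtasks(rowCount int, serverInfo []*infosync.ServerInfo) [][2]int {
	n := len(serverInfo)
	if n == 0 || rowCount == 0 {
		return nil
	}
	step := (rowCount + n - 1) / n // ceil(rowCount / n) rows per subtask
	ranges := make([][2]int, 0, n)
	for start := 0; start < rowCount; start += step {
		end := start + step
		if end > rowCount {
			end = rowCount
		}
		ranges = append(ranges, [2]int{start, end})
	}
	return ranges
}

With the dispatcher owning node selection, the subtask count can no longer disagree with the set of nodes the subtasks are actually dispatched to.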
22 changes: 13 additions & 9 deletions pkg/ddl/backfilling_dispatcher_test.go
@@ -68,7 +68,9 @@ func TestBackfillingDispatcherLocalMode(t *testing.T) {
// 1.1 OnNextSubtasksBatch
gTask.Step = dsp.GetNextStep(gTask)
require.Equal(t, ddl.StepReadIndex, gTask.Step)
metas, err := dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, gTask.Step)
serverInfos, _, err := dsp.GetEligibleInstances(context.Background(), gTask)
require.NoError(t, err)
metas, err := dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, serverInfos, gTask.Step)
require.NoError(t, err)
require.Equal(t, len(tblInfo.Partition.Definitions), len(metas))
for i, par := range tblInfo.Partition.Definitions {
@@ -81,7 +83,7 @@ func TestBackfillingDispatcherLocalMode(t *testing.T) {
gTask.State = proto.TaskStateRunning
gTask.Step = dsp.GetNextStep(gTask)
require.Equal(t, proto.StepDone, gTask.Step)
metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, gTask.Step)
metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, serverInfos, gTask.Step)
require.NoError(t, err)
require.Len(t, metas, 0)

@@ -98,7 +100,7 @@ func TestBackfillingDispatcherLocalMode(t *testing.T) {
// 2.1 empty table
tk.MustExec("create table t1(id int primary key, v int)")
gTask = createAddIndexGlobalTask(t, dom, "test", "t1", proto.Backfill, false)
metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, gTask.Step)
metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, serverInfos, gTask.Step)
require.NoError(t, err)
require.Equal(t, 0, len(metas))
// 2.2 non empty table.
@@ -110,15 +112,15 @@ func TestBackfillingDispatcherLocalMode(t *testing.T) {
gTask = createAddIndexGlobalTask(t, dom, "test", "t2", proto.Backfill, false)
// 2.2.1 stepInit
gTask.Step = dsp.GetNextStep(gTask)
metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, gTask.Step)
metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, serverInfos, gTask.Step)
require.NoError(t, err)
require.Equal(t, 1, len(metas))
require.Equal(t, ddl.StepReadIndex, gTask.Step)
// 2.2.2 StepReadIndex
gTask.State = proto.TaskStateRunning
gTask.Step = dsp.GetNextStep(gTask)
require.Equal(t, proto.StepDone, gTask.Step)
metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, gTask.Step)
metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, serverInfos, gTask.Step)
require.NoError(t, err)
require.Equal(t, 0, len(metas))
}
@@ -154,9 +156,11 @@ func TestBackfillingDispatcherGlobalSortMode(t *testing.T) {
taskID, err := mgr.AddNewGlobalTask(task.Key, proto.Backfill, 1, task.Meta)
require.NoError(t, err)
task.ID = taskID
serverInfos, _, err := dsp.GetEligibleInstances(context.Background(), task)
require.NoError(t, err)

// 1. to read-index stage
subtaskMetas, err := dsp.OnNextSubtasksBatch(ctx, dsp, task, dsp.GetNextStep(task))
subtaskMetas, err := dsp.OnNextSubtasksBatch(ctx, dsp, task, serverInfos, dsp.GetNextStep(task))
require.NoError(t, err)
require.Len(t, subtaskMetas, 1)
task.Step = ext.GetNextStep(task)
@@ -197,7 +201,7 @@ func TestBackfillingDispatcherGlobalSortMode(t *testing.T) {
t.Cleanup(func() {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/ddl/forceMergeSort"))
})
subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, dsp, task, ext.GetNextStep(task))
subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, dsp, task, serverInfos, ext.GetNextStep(task))
require.NoError(t, err)
require.Len(t, subtaskMetas, 1)
task.Step = ext.GetNextStep(task)
@@ -236,13 +240,13 @@ func TestBackfillingDispatcherGlobalSortMode(t *testing.T) {
t.Cleanup(func() {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/ddl/mockWriteIngest"))
})
subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, dsp, task, ext.GetNextStep(task))
subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, dsp, task, serverInfos, ext.GetNextStep(task))
require.NoError(t, err)
require.Len(t, subtaskMetas, 1)
task.Step = ext.GetNextStep(task)
require.Equal(t, ddl.StepWriteAndIngest, task.Step)
// 4. to done stage.
subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, dsp, task, ext.GetNextStep(task))
subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, dsp, task, serverInfos, ext.GetNextStep(task))
require.NoError(t, err)
require.Len(t, subtaskMetas, 0)
task.Step = ext.GetNextStep(task)
65 changes: 41 additions & 24 deletions pkg/disttask/framework/dispatcher/dispatcher.go
@@ -373,10 +373,17 @@ func (d *BaseDispatcher) replaceDeadNodesIfAny() error {
if err != nil {
return err
}
eligibleServerInfos, err := d.GetEligibleInstances(d.ctx, d.Task)

eligibleServerInfos, filter, err := d.GetEligibleInstances(d.ctx, d.Task)
if err != nil {
return err
}
if filter {
eligibleServerInfos, err = d.filterByRole(eligibleServerInfos)
if err != nil {
return err
}
}
newInfos := serverInfos[:0]
for _, m := range serverInfos {
found := false
@@ -545,7 +552,25 @@ func (d *BaseDispatcher) onNextStage() (err error) {

for {
// 3. generate a batch of subtasks.
metas, err := d.OnNextSubtasksBatch(d.ctx, d, d.Task, nextStep)
// select all available TiDB nodes for the task.
serverNodes, filter, err := d.GetEligibleInstances(d.ctx, d.Task)
logutil.Logger(d.logCtx).Debug("eligible instances", zap.Int("num", len(serverNodes)))

if err != nil {
return err
}
if filter {
serverNodes, err = d.filterByRole(serverNodes)
if err != nil {
return err
}
}
logutil.Logger(d.logCtx).Info("eligible instances", zap.Int("num", len(serverNodes)))
if len(serverNodes) == 0 {
return errors.New("no available TiDB node to dispatch subtasks")
}

metas, err := d.OnNextSubtasksBatch(d.ctx, d, d.Task, serverNodes, nextStep)
if err != nil {
logutil.Logger(d.logCtx).Warn("generate part of subtasks failed", zap.Error(err))
return d.handlePlanErr(err)
@@ -556,7 +581,7 @@
})

// 4. dispatch batch of subtasks to EligibleInstances.
err = d.dispatchSubTask(nextStep, metas)
err = d.dispatchSubTask(nextStep, metas, serverNodes)
if err != nil {
return err
}
Expand All @@ -572,27 +597,11 @@ func (d *BaseDispatcher) onNextStage() (err error) {
return nil
}

func (d *BaseDispatcher) dispatchSubTask(subtaskStep proto.Step, metas [][]byte) error {
func (d *BaseDispatcher) dispatchSubTask(
subtaskStep proto.Step,
metas [][]byte,
serverNodes []*infosync.ServerInfo) error {
logutil.Logger(d.logCtx).Info("dispatch subtasks", zap.Stringer("state", d.Task.State), zap.Int64("step", int64(d.Task.Step)), zap.Uint64("concurrency", d.Task.Concurrency), zap.Int("subtasks", len(metas)))

// select all available TiDB nodes for task.
serverNodes, err := d.GetEligibleInstances(d.ctx, d.Task)
logutil.Logger(d.logCtx).Debug("eligible instances", zap.Int("num", len(serverNodes)))

if err != nil {
return err
}
// 4. filter by role.
serverNodes, err = d.filterByRole(serverNodes)
if err != nil {
return err
}

logutil.Logger(d.logCtx).Info("eligible instances", zap.Int("num", len(serverNodes)))

if len(serverNodes) == 0 {
return errors.New("no available TiDB node to dispatch subtasks")
}
d.taskNodes = make([]string, len(serverNodes))
for i := range serverNodes {
d.taskNodes[i] = disttaskutil.GenerateExecID(serverNodes[i].IP, serverNodes[i].Port)
@@ -619,8 +628,14 @@ func (d *BaseDispatcher) handlePlanErr(err error) error {
return d.updateTask(proto.TaskStateFailed, nil, RetrySQLTimes)
}

// MockServerInfo is exported for dispatcher_test.go.
var MockServerInfo []*infosync.ServerInfo

// GenerateSchedulerNodes generates eligible TiDB nodes.
func GenerateSchedulerNodes(ctx context.Context) (serverNodes []*infosync.ServerInfo, err error) {
failpoint.Inject("mockSchedulerNodes", func() {
failpoint.Return(MockServerInfo, nil)
})
var serverInfos map[string]*infosync.ServerInfo
_, etcd := ctx.Value("etcd").(bool)
if intest.InTest && !etcd {
@@ -668,7 +683,9 @@ func (d *BaseDispatcher) filterByRole(infos []*infosync.ServerInfo) ([]*infosync

// GetAllSchedulerIDs gets all the scheduler IDs.
func (d *BaseDispatcher) GetAllSchedulerIDs(ctx context.Context, task *proto.Task) ([]string, error) {
serverInfos, err := d.GetEligibleInstances(ctx, task)
// We get all servers instead of eligible servers here
// because eligible servers may change during the task execution.
serverInfos, err := GenerateSchedulerNodes(ctx)
if err != nil {
return nil, err
}
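This hunk is the heart of the fix: node selection and the optional role filter now run in onNextStage before OnNextSubtasksBatch builds the plan, and the same filtered slice is reused by dispatchSubTask. Previously the filter ran only inside dispatchSubTask, so the plan could be generated against an unfiltered node list. A condensed sketch of the new ordering; selectNodes and its function-valued parameter are invented for illustration, the standard library errors package stands in for the project's error helpers, and the error message matches the diff.

package sketch

import (
	"errors"

	"github.com/pingcap/tidb/pkg/domain/infosync"
)

// selectNodes mirrors the selection that now happens before plan generation:
// take the extension's eligible nodes, apply the role filter only when the
// extension requested it, and fail early if nothing survives.
func selectNodes(
	eligible []*infosync.ServerInfo,
	filter bool,
	filterByRole func([]*infosync.ServerInfo) ([]*infosync.ServerInfo, error),
) ([]*infosync.ServerInfo, error) {
	if filter {
		var err error
		if eligible, err = filterByRole(eligible); err != nil {
			return nil, err
		}
	}
	if len(eligible) == 0 {
		return nil, errors.New("no available TiDB node to dispatch subtasks")
	}
	// onNextStage hands the surviving nodes to both OnNextSubtasksBatch and
	// dispatchSubTask, so the plan and the dispatch target the same set.
	return eligible, nil
}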
18 changes: 10 additions & 8 deletions pkg/disttask/framework/dispatcher/dispatcher_test.go
@@ -51,7 +51,7 @@ type testDispatcherExt struct{}
func (*testDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}

func (*testDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ proto.Step) (metas [][]byte, err error) {
func (*testDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ []*infosync.ServerInfo, _ proto.Step) (metas [][]byte, err error) {
return nil, nil
}

@@ -61,8 +61,8 @@ func (*testDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.TaskHandle,

var mockedAllServerInfos = []*infosync.ServerInfo{}

func (*testDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
return mockedAllServerInfos, nil
func (*testDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) {
return mockedAllServerInfos, true, nil
}

func (*testDispatcherExt) IsRetryableErr(error) bool {
@@ -78,7 +78,7 @@ type numberExampleDispatcherExt struct{}
func (*numberExampleDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}

func (n *numberExampleDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task, _ proto.Step) (metas [][]byte, err error) {
func (n *numberExampleDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task, _ []*infosync.ServerInfo, _ proto.Step) (metas [][]byte, err error) {
switch task.Step {
case proto.StepInit:
for i := 0; i < subtaskCnt; i++ {
@@ -99,8 +99,9 @@ func (n *numberExampleDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.
return nil, nil
}

func (*numberExampleDispatcherExt) GetEligibleInstances(ctx context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
return dispatcher.GenerateSchedulerNodes(ctx)
func (*numberExampleDispatcherExt) GetEligibleInstances(ctx context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) {
serverInfo, err := dispatcher.GenerateSchedulerNodes(ctx)
return serverInfo, true, err
}

func (*numberExampleDispatcherExt) IsRetryableErr(error) bool {
@@ -157,7 +158,7 @@ func TestGetInstance(t *testing.T) {
return gtk.Session(), nil
}, 1, 1, time.Second)
defer pool.Close()

require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/disttask/framework/dispatcher/mockSchedulerNodes", "return()"))
dspManager, mgr := MockDispatcherManager(t, pool)
// test no server
task := &proto.Task{ID: 1, Type: proto.TaskTypeExample}
@@ -173,7 +174,7 @@ func TestGetInstance(t *testing.T) {
uuids := []string{"ddl_id_1", "ddl_id_2"}
serverIDs := []string{"10.123.124.10:32457", "[ABCD:EF01:2345:6789:ABCD:EF01:2345:6789]:65535"}

mockedAllServerInfos = []*infosync.ServerInfo{
dispatcher.MockServerInfo = []*infosync.ServerInfo{
{
ID: uuids[0],
IP: "10.123.124.10",
@@ -214,6 +215,7 @@ func TestGetInstance(t *testing.T) {
require.NoError(t, err)
require.Len(t, instanceIDs, len(serverIDs))
require.ElementsMatch(t, instanceIDs, serverIDs)
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/disttask/framework/dispatcher/mockSchedulerNodes"))
}

func TestTaskFailInManager(t *testing.T) {
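The test now mocks node discovery through the new mockSchedulerNodes failpoint plus dispatcher.MockServerInfo instead of the package-local mockedAllServerInfos variable. A minimal sketch of that pattern, assuming a failpoint-enabled build; the test name, package name, and node values are invented, while the failpoint path, MockServerInfo, and GenerateSchedulerNodes come from the diff.

package dispatcher_test

import (
	"context"
	"testing"

	"github.com/pingcap/failpoint"
	"github.com/pingcap/tidb/pkg/disttask/framework/dispatcher"
	"github.com/pingcap/tidb/pkg/domain/infosync"
	"github.com/stretchr/testify/require"
)

func TestMockSchedulerNodesSketch(t *testing.T) {
	fp := "github.com/pingcap/tidb/pkg/disttask/framework/dispatcher/mockSchedulerNodes"
	require.NoError(t, failpoint.Enable(fp, "return()"))
	t.Cleanup(func() {
		require.NoError(t, failpoint.Disable(fp))
	})

	// While the failpoint is active, GenerateSchedulerNodes short-circuits
	// and returns whatever is stored in MockServerInfo.
	dispatcher.MockServerInfo = []*infosync.ServerInfo{
		{ID: "ddl_id_1", IP: "10.123.124.10", Port: 32457},
	}

	nodes, err := dispatcher.GenerateSchedulerNodes(context.Background())
	require.NoError(t, err)
	require.Len(t, nodes, 1)
}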
5 changes: 3 additions & 2 deletions pkg/disttask/framework/dispatcher/interface.go
@@ -64,7 +64,7 @@ type Extension interface {
// 1. the task is pending and entering its first step.
// 2. all dispatched subtasks have finished with no error.
// when next step is StepDone, it should return nil, nil.
OnNextSubtasksBatch(ctx context.Context, h TaskHandle, task *proto.Task, step proto.Step) (subtaskMetas [][]byte, err error)
OnNextSubtasksBatch(ctx context.Context, h TaskHandle, task *proto.Task, serverInfo []*infosync.ServerInfo, step proto.Step) (subtaskMetas [][]byte, err error)

// OnErrStage is called when:
// 1. subtask is finished with error.
@@ -73,7 +73,8 @@ type Extension interface {

// GetEligibleInstances is used to get the eligible instances for the task.
// On certain conditions we may want to use specific instances for the task, such as instances with more disk space.
GetEligibleInstances(ctx context.Context, task *proto.Task) ([]*infosync.ServerInfo, error)
// The bool return value indicates whether to filter instances by role.
GetEligibleInstances(ctx context.Context, task *proto.Task) ([]*infosync.ServerInfo, bool, error)

// IsRetryableErr is used to check whether the error occurred in dispatcher is retryable.
IsRetryableErr(err error) bool
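With the extended contract, an implementation now also tells the dispatcher whether to apply its role filter. A minimal sketch of one possible implementation; pinnedDispatcherExt is hypothetical, only the method signature and GenerateSchedulerNodes come from this PR, and a complete dispatcher.Extension must implement the remaining methods as well.

package sketch

import (
	"context"

	"github.com/pingcap/tidb/pkg/disttask/framework/dispatcher"
	"github.com/pingcap/tidb/pkg/disttask/framework/proto"
	"github.com/pingcap/tidb/pkg/domain/infosync"
)

// pinnedDispatcherExt shows only GetEligibleInstances.
type pinnedDispatcherExt struct {
	pinned []*infosync.ServerInfo
}

func (e *pinnedDispatcherExt) GetEligibleInstances(ctx context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) {
	if len(e.pinned) > 0 {
		// The nodes were chosen explicitly, so ask the dispatcher not to
		// filter them by role again.
		return e.pinned, false, nil
	}
	// Otherwise fall back to all live TiDB nodes and let the dispatcher
	// apply its role filter.
	nodes, err := dispatcher.GenerateSchedulerNodes(ctx)
	return nodes, true, err
}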
4 changes: 2 additions & 2 deletions pkg/disttask/framework/framework_dynamic_dispatch_test.go
@@ -37,7 +37,7 @@ var _ dispatcher.Extension = (*testDynamicDispatcherExt)(nil)

func (*testDynamicDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {}

func (dsp *testDynamicDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task, _ proto.Step) (metas [][]byte, err error) {
func (dsp *testDynamicDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task, _ []*infosync.ServerInfo, _ proto.Step) (metas [][]byte, err error) {
// step1
if gTask.Step == proto.StepInit {
dsp.cnt++
Expand Down Expand Up @@ -72,7 +72,7 @@ func (dsp *testDynamicDispatcherExt) GetNextStep(task *proto.Task) proto.Step {
}
}

func (*testDynamicDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
func (*testDynamicDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) {
return generateSchedulerNodes4Test()
}

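generateSchedulerNodes4Test is not shown in this diff, but since the updated GetEligibleInstances stubs return it directly, it must now return the three-value form. A plausible shape, entirely an assumption apart from the signature:

package framework_test

import (
	"github.com/pingcap/tidb/pkg/domain/infosync"
)

// generateSchedulerNodes4Test returns a fixed node list plus filter=true so
// the test extensions exercise the same role-filtering path as production
// code. The node values are invented for illustration.
func generateSchedulerNodes4Test() ([]*infosync.ServerInfo, bool, error) {
	nodes := []*infosync.ServerInfo{
		{ID: "test-node-1", IP: "127.0.0.1", Port: 4000},
	}
	return nodes, true, nil
}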
8 changes: 4 additions & 4 deletions pkg/disttask/framework/framework_err_handling_test.go
@@ -40,7 +40,7 @@ var (
func (*planErrDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}

func (p *planErrDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task, _ proto.Step) (metas [][]byte, err error) {
func (p *planErrDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task, _ []*infosync.ServerInfo, _ proto.Step) (metas [][]byte, err error) {
if gTask.Step == proto.StepInit {
if p.callTime == 0 {
p.callTime++
@@ -70,7 +70,7 @@ func (p *planErrDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.TaskHa
return []byte("planErrTask"), nil
}

func (*planErrDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
func (*planErrDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) {
return generateSchedulerNodes4Test()
}

@@ -96,15 +96,15 @@ type planNotRetryableErrDispatcherExt struct {
func (*planNotRetryableErrDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}

func (p *planNotRetryableErrDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ proto.Step) (metas [][]byte, err error) {
func (p *planNotRetryableErrDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ []*infosync.ServerInfo, _ proto.Step) (metas [][]byte, err error) {
return nil, errors.New("not retryable err")
}

func (*planNotRetryableErrDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ []error) (meta []byte, err error) {
return nil, errors.New("not retryable err")
}

func (*planNotRetryableErrDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
func (*planNotRetryableErrDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) {
return generateSchedulerNodes4Test()
}

4 changes: 2 additions & 2 deletions pkg/disttask/framework/framework_ha_test.go
@@ -37,7 +37,7 @@ var _ dispatcher.Extension = (*haTestDispatcherExt)(nil)
func (*haTestDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}

func (dsp *haTestDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task, _ proto.Step) (metas [][]byte, err error) {
func (dsp *haTestDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task, _ []*infosync.ServerInfo, _ proto.Step) (metas [][]byte, err error) {
if gTask.Step == proto.StepInit {
dsp.cnt = 10
return [][]byte{
@@ -70,7 +70,7 @@ func (*haTestDispatcherExt) OnErrStage(ctx context.Context, h dispatcher.TaskHan
return nil, nil
}

func (*haTestDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
func (*haTestDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, bool, error) {
return generateSchedulerNodes4Test()
}
