diff --git a/executor/index_merge_reader.go b/executor/index_merge_reader.go
index 97cdb0ba5821d..610c15b809bfa 100644
--- a/executor/index_merge_reader.go
+++ b/executor/index_merge_reader.go
@@ -294,7 +294,7 @@ func (e *IndexMergeReaderExecutor) startIndexMergeProcessWorker(ctx context.Cont
 func (e *IndexMergeReaderExecutor) startPartialIndexWorker(ctx context.Context, exitCh <-chan struct{}, fetchCh chan<- *indexMergeTableTask, workID int) error {
     failpoint.Inject("testIndexMergeResultChCloseEarly", func(_ failpoint.Value) {
         // Wait for processWorker to close resultCh.
-        time.Sleep(2)
+        time.Sleep(time.Second * 2)
         // Should use fetchCh instead of resultCh to send error.
         syncErr(ctx, e.finished, fetchCh, errors.New("testIndexMergeResultChCloseEarly"))
     })
@@ -375,8 +375,10 @@ func (e *IndexMergeReaderExecutor) startPartialIndexWorker(ctx context.Context,
             for parTblIdx, keyRange := range keyRanges {
                 // check if this executor is closed
                 select {
+                case <-ctx.Done():
+                    return
                 case <-e.finished:
-                    break
+                    return
                 default:
                 }

@@ -499,8 +501,10 @@ func (e *IndexMergeReaderExecutor) startPartialTableWorker(ctx context.Context,
         for parTblIdx, tbl := range tbls {
             // check if this executor is closed
             select {
+            case <-ctx.Done():
+                return
             case <-e.finished:
-                break
+                return
             default:
             }

@@ -751,6 +755,12 @@ func (e *IndexMergeReaderExecutor) Next(ctx context.Context, req *chunk.Chunk) e
 }

 func (e *IndexMergeReaderExecutor) getResultTask() (*indexMergeTableTask, error) {
+    failpoint.Inject("testIndexMergeMainReturnEarly", func(_ failpoint.Value) {
+        // To make sure processWorker makes resultCh full.
+        // When the main goroutine closes finished, processWorker may get stuck writing to resultCh.
+        time.Sleep(time.Second * 20)
+        failpoint.Return(nil, errors.New("failpoint testIndexMergeMainReturnEarly"))
+    })
     if e.resultCurr != nil && e.resultCurr.cursor < len(e.resultCurr.rows) {
         return e.resultCurr, nil
     }
@@ -778,6 +788,7 @@ func handleWorkerPanic(ctx context.Context, finished <-chan struct{}, ch chan<-
         defer close(ch)
     }
     if r == nil {
+        logutil.BgLogger().Info("worker finish without panic", zap.Any("worker", worker))
         return
     }

@@ -840,7 +851,20 @@ func (w *indexMergeProcessWorker) fetchLoopUnion(ctx context.Context, fetchCh <-
     failpoint.Inject("testIndexMergePanicProcessWorkerUnion", nil)

     distinctHandles := make(map[int64]*kv.HandleMap)
-    for task := range fetchCh {
+    for {
+        var ok bool
+        var task *indexMergeTableTask
+        select {
+        case <-ctx.Done():
+            return
+        case <-finished:
+            return
+        case task, ok = <-fetchCh:
+            if !ok {
+                return
+            }
+        }
+
         select {
         case err := <-task.doneCh:
             // If got error from partialIndexWorker/partialTableWorker, stop processing.
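Note on the `break` → `return` changes in the partial worker loops above: in Go, `break` inside a `select` exits only the `select` statement, not the enclosing `for` loop, so the old code kept iterating after `e.finished` was closed. A minimal, self-contained sketch of the pitfall (names are illustrative, not from this patch):

```go
package main

import "fmt"

func main() {
	finished := make(chan struct{})
	close(finished) // simulate the executor already being closed

	for i := 0; i < 3; i++ {
		select {
		case <-finished:
			// Exits only the select, not the for loop, so the
			// loop body below still runs on every iteration.
			break
		default:
		}
		fmt.Println("iteration", i) // prints 0, 1, 2 despite finished being closed
	}
	// Using `return` instead, as the patch does, is what actually stops the worker.
}
```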
@@ -876,7 +900,7 @@ func (w *indexMergeProcessWorker) fetchLoopUnion(ctx context.Context, fetchCh <-
         if len(fhs) == 0 {
             continue
         }
-        task := &indexMergeTableTask{
+        task = &indexMergeTableTask{
             lookupTableTask: lookupTableTask{
                 handles: fhs,
                 doneCh:  make(chan error, 1),
@@ -887,13 +911,27 @@ func (w *indexMergeProcessWorker) fetchLoopUnion(ctx context.Context, fetchCh <-
         if w.stats != nil {
             w.stats.IndexMergeProcess += time.Since(start)
         }
+        failpoint.Inject("testIndexMergeProcessWorkerUnionHang", func(_ failpoint.Value) {
+            for i := 0; i < cap(resultCh); i++ {
+                select {
+                case resultCh <- &indexMergeTableTask{}:
+                default:
+                }
+            }
+        })
         select {
         case <-ctx.Done():
             return
         case <-finished:
             return
         case workCh <- task:
-            resultCh <- task
+            select {
+            case <-ctx.Done():
+                return
+            case <-finished:
+                return
+            case resultCh <- task:
+            }
         }
     }
 }
@@ -992,6 +1030,14 @@ func (w *intersectionProcessWorker) doIntersectionPerPartition(ctx context.Conte
                 zap.Int("parTblIdx", parTblIdx), zap.Int("task.handles", len(task.handles)))
         }
     }
+    failpoint.Inject("testIndexMergeProcessWorkerIntersectionHang", func(_ failpoint.Value) {
+        for i := 0; i < cap(resultCh); i++ {
+            select {
+            case resultCh <- &indexMergeTableTask{}:
+            default:
+            }
+        }
+    })
     for _, task := range tasks {
         select {
         case <-ctx.Done():
@@ -999,7 +1045,13 @@ func (w *intersectionProcessWorker) doIntersectionPerPartition(ctx context.Conte
         case <-finished:
             return
         case workCh <- task:
-            resultCh <- task
+            select {
+            case <-ctx.Done():
+                return
+            case <-finished:
+                return
+            case resultCh <- task:
+            }
         }
     }
 }
@@ -1058,29 +1110,47 @@ func (w *indexMergeProcessWorker) fetchLoopIntersection(ctx context.Context, fet
         }, handleWorkerPanic(ctx, finished, resultCh, errCh, partTblIntersectionWorkerType))
         workers = append(workers, worker)
     }
-loop:
-    for task := range fetchCh {
+    defer func() {
+        for _, processWorker := range workers {
+            close(processWorker.workerCh)
+        }
+        wg.Wait()
+    }()
+    for {
+        var ok bool
+        var task *indexMergeTableTask
+        select {
+        case <-ctx.Done():
+            return
+        case <-finished:
+            return
+        case task, ok = <-fetchCh:
+            if !ok {
+                return
+            }
+        }
+
         select {
         case err := <-task.doneCh:
             // If got error from partialIndexWorker/partialTableWorker, stop processing.
             if err != nil {
                 syncErr(ctx, finished, resultCh, err)
-                break loop
+                return
             }
         default:
         }
         select {
+        case <-ctx.Done():
+            return
+        case <-finished:
+            return
         case workers[task.parTblIdx%workerCnt].workerCh <- task:
         case <-errCh:
             // If got error from intersectionProcessWorker, stop processing.
-            break loop
+            return
         }
     }
-    for _, processWorker := range workers {
-        close(processWorker.workerCh)
-    }
-    wg.Wait()
 }

 type partialIndexWorker struct {
@@ -1229,12 +1299,14 @@ func (w *indexMergeTableScanWorker) pickAndExecTask(ctx context.Context, task **
     for {
         waitStart := time.Now()
         select {
+        case <-ctx.Done():
+            return
+        case <-w.finished:
+            return
         case *task, ok = <-w.workCh:
             if !ok {
                 return
             }
-        case <-w.finished:
-            return
         }
         // Make sure panic failpoint is after fetch task from workCh.
         // Otherwise cannot send error to task.doneCh.
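The recurring change above is a guarded send: every bare `resultCh <- task` becomes a `select` that also watches `ctx.Done()` and `finished`, so a producer cannot block forever on a full channel after the consumer has returned. That blocked send is exactly the hang the `*Hang` failpoints simulate by pre-filling `resultCh` to capacity. A minimal sketch of the pattern, with stand-in types and names rather than the executor's real ones:

```go
package main

import (
	"context"
	"fmt"
)

type task struct{ id int }

// trySend delivers t to resultCh unless the query is cancelled or finished
// first; a false result tells the producer to stop.
func trySend(ctx context.Context, finished <-chan struct{}, resultCh chan<- *task, t *task) bool {
	select {
	case <-ctx.Done():
		return false
	case <-finished:
		return false
	case resultCh <- t:
		return true
	}
}

func main() {
	resultCh := make(chan *task, 1)
	finished := make(chan struct{})

	resultCh <- &task{id: 0} // fill the channel: a bare send would now block
	close(finished)          // consumer has gone away

	// The guarded send observes finished and backs out instead of hanging.
	fmt.Println("sent:", trySend(context.Background(), finished, resultCh, &task{id: 1}))
}
```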
@@ -1255,13 +1327,20 @@ func (w *indexMergeTableScanWorker) pickAndExecTask(ctx context.Context, task **
             atomic.AddInt64(&w.stats.TableTaskNum, 1)
         }
         failpoint.Inject("testIndexMergePickAndExecTaskPanic", nil)
-        (*task).doneCh <- err
+        select {
+        case <-ctx.Done():
+            return
+        case <-w.finished:
+            return
+        case (*task).doneCh <- err:
+        }
     }
 }

 func (w *indexMergeTableScanWorker) handleTableScanWorkerPanic(ctx context.Context, finished <-chan struct{}, task **indexMergeTableTask, worker string) func(r interface{}) {
     return func(r interface{}) {
         if r == nil {
+            logutil.BgLogger().Info("worker finish without panic", zap.Any("worker", worker))
             return
         }
diff --git a/executor/index_merge_reader_test.go b/executor/index_merge_reader_test.go
index d30fce71a180e..46a1206460074 100644
--- a/executor/index_merge_reader_test.go
+++ b/executor/index_merge_reader_test.go
@@ -798,6 +798,49 @@ func TestIntersectionMemQuota(t *testing.T) {
     require.Contains(t, err.Error(), "Out Of Memory Quota!")
 }

+func setupPartitionTableHelper(tk *testkit.TestKit) {
+    tk.MustExec("use test")
+    tk.MustExec("drop table if exists t1")
+    tk.MustExec("create table t1(c1 int, c2 bigint, c3 bigint, primary key(c1), key(c2), key(c3)) partition by hash(c1) partitions 10;")
+    insertStr := "insert into t1 values(0, 0, 0)"
+    for i := 1; i < 1000; i++ {
+        insertStr += fmt.Sprintf(", (%d, %d, %d)", i, i, i)
+    }
+    tk.MustExec(insertStr)
+    tk.MustExec("analyze table t1;")
+    tk.MustExec("set tidb_partition_prune_mode = 'dynamic'")
+}
+
+func TestIndexMergeProcessWorkerHang(t *testing.T) {
+    store := testkit.CreateMockStore(t)
+    tk := testkit.NewTestKit(t, store)
+    setupPartitionTableHelper(tk)
+
+    var err error
+    sql := "select /*+ use_index_merge(t1) */ c1 from t1 where c1 < 900 or c2 < 1000;"
+    res := tk.MustQuery("explain " + sql).Rows()
+    require.Contains(t, res[1][0], "IndexMerge")
+
+    require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/testIndexMergeMainReturnEarly", "return()"))
+    require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/testIndexMergeProcessWorkerUnionHang", "return(true)"))
+    err = tk.QueryToErr(sql)
+    require.Contains(t, err.Error(), "testIndexMergeMainReturnEarly")
+    require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/testIndexMergeMainReturnEarly"))
+    require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/testIndexMergeProcessWorkerUnionHang"))
+
+    sql = "select /*+ use_index_merge(t1, c2, c3) */ c1 from t1 where c2 < 900 and c3 < 1000;"
+    res = tk.MustQuery("explain " + sql).Rows()
+    require.Contains(t, res[1][0], "IndexMerge")
+    require.Contains(t, res[1][4], "intersection")
+
+    require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/testIndexMergeMainReturnEarly", "return()"))
+    require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/testIndexMergeProcessWorkerIntersectionHang", "return(true)"))
+    err = tk.QueryToErr(sql)
+    require.Contains(t, err.Error(), "testIndexMergeMainReturnEarly")
+    require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/testIndexMergeMainReturnEarly"))
+    require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/testIndexMergeProcessWorkerIntersectionHang"))
+}
+
 func TestIndexMergePanic(t *testing.T) {
     store := testkit.CreateMockStore(t)
     tk := testkit.NewTestKit(t, store)
@@ -811,16 +854,7 @@ func TestIndexMergePanic(t *testing.T) {
     tk.MustExec("select /*+ use_index_merge(t1, primary, c2, c3) */ c1 from t1 where c1 < 100 or c2 < 100")
     require.NoError(t,
         failpoint.Disable("github.com/pingcap/tidb/executor/testIndexMergeResultChCloseEarly"))
-    tk.MustExec("use test")
-    tk.MustExec("drop table if exists t1")
-    tk.MustExec("create table t1(c1 int, c2 bigint, c3 bigint, primary key(c1), key(c2), key(c3)) partition by hash(c1) partitions 10;")
-    insertStr := "insert into t1 values(0, 0, 0)"
-    for i := 1; i < 1000; i++ {
-        insertStr += fmt.Sprintf(", (%d, %d, %d)", i, i, i)
-    }
-    tk.MustExec(insertStr)
-    tk.MustExec("analyze table t1;")
-    tk.MustExec("set tidb_partition_prune_mode = 'dynamic'")
+    setupPartitionTableHelper(tk)

     minV := 200
     maxV := 1000
@@ -885,17 +919,7 @@ func TestIndexMergePanic(t *testing.T) {

 func TestIndexMergeCoprGoroutinesLeak(t *testing.T) {
     store := testkit.CreateMockStore(t)
     tk := testkit.NewTestKit(t, store)
-
-    tk.MustExec("use test")
-    tk.MustExec("drop table if exists t1")
-    tk.MustExec("create table t1(c1 int, c2 bigint, c3 bigint, primary key(c1), key(c2), key(c3));")
-    insertStr := "insert into t1 values(0, 0, 0)"
-    for i := 1; i < 1000; i++ {
-        insertStr += fmt.Sprintf(", (%d, %d, %d)", i, i, i)
-    }
-    tk.MustExec(insertStr)
-    tk.MustExec("analyze table t1;")
-    tk.MustExec("set tidb_partition_prune_mode = 'dynamic'")
+    setupPartitionTableHelper(tk)

     var err error
     sql := "select /*+ use_index_merge(t1) */ c1 from t1 where c1 < 900 or c2 < 1000;"
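One structural point worth calling out from the reader changes: `fetchLoopIntersection` now closes its worker channels and calls `wg.Wait()` from a `defer` instead of after the loop, so every early `return` added by this patch still shuts the intersection workers down. A minimal sketch of that shutdown shape (worker counts, channel names, and the dispatch function are hypothetical, not the executor's):

```go
package main

import (
	"fmt"
	"sync"
)

func dispatch(tasks []int) {
	const workerCnt = 2
	var wg sync.WaitGroup
	workerChs := make([]chan int, workerCnt)
	for i := 0; i < workerCnt; i++ {
		ch := make(chan int, 4)
		workerChs[i] = ch
		wg.Add(1)
		go func(id int, ch <-chan int) {
			defer wg.Done()
			for task := range ch { // exits when ch is closed
				fmt.Println("worker", id, "handled task", task)
			}
		}(i, ch)
	}
	// Runs on every return path: close each worker channel, then wait
	// for the workers to drain, mirroring the deferred cleanup in the patch.
	defer func() {
		for _, ch := range workerChs {
			close(ch)
		}
		wg.Wait()
	}()
	for _, task := range tasks {
		if task < 0 {
			return // an early return still triggers the deferred cleanup
		}
		workerChs[task%workerCnt] <- task
	}
}

func main() {
	dispatch([]int{0, 1, 2, -1, 3})
}
```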