diff --git a/pkg/executor/sortexec/BUILD.bazel b/pkg/executor/sortexec/BUILD.bazel
index 960d3bee40831..9800055c4ebbc 100644
--- a/pkg/executor/sortexec/BUILD.bazel
+++ b/pkg/executor/sortexec/BUILD.bazel
@@ -27,12 +27,14 @@ go_library(
         "//pkg/util",
         "//pkg/util/channel",
         "//pkg/util/chunk",
+        "//pkg/util/dbterror/exeerrors",
         "//pkg/util/disk",
         "//pkg/util/logutil",
         "//pkg/util/memory",
         "//pkg/util/sqlkiller",
         "@com_github_pingcap_errors//:errors",
         "@com_github_pingcap_failpoint//:failpoint",
+        "@com_github_stretchr_testify//require",
         "@org_uber_go_zap//:zap",
     ],
 )
@@ -42,7 +44,7 @@ go_test(
     timeout = "short",
     srcs = ["sort_test.go"],
     flaky = True,
-    shard_count = 16,
+    shard_count = 17,
     deps = [
         "//pkg/config",
         "//pkg/sessionctx/variable",
diff --git a/pkg/executor/sortexec/parallel_sort_spill_helper.go b/pkg/executor/sortexec/parallel_sort_spill_helper.go
index 77b2754d24486..a2f477c79fa81 100644
--- a/pkg/executor/sortexec/parallel_sort_spill_helper.go
+++ b/pkg/executor/sortexec/parallel_sort_spill_helper.go
@@ -108,6 +108,12 @@ func (p *parallelSortSpillHelper) spill() (err error) {
 		}
 	}()
 
+	p.setInSpilling()
+
+	// Spill is done, broadcast to wake up all sleep goroutines
+	defer p.cond.Broadcast()
+	defer p.setNotSpilled()
+
 	select {
 	case <-p.finishCh:
 		return nil
@@ -138,11 +144,6 @@ func (p *parallelSortSpillHelper) spill() (err error) {
 		}
 	}
 	workerWaiter.Wait()
-	p.setInSpilling()
-
-	// Spill is done, broadcast to wake up all sleep goroutines
-	defer p.cond.Broadcast()
-	defer p.setNotSpilled()
 
 	totalRows := 0
 	for i := range sortedRowsIters {
diff --git a/pkg/executor/sortexec/parallel_sort_worker.go b/pkg/executor/sortexec/parallel_sort_worker.go
index 1d7a4d7aaa7b6..9fcdaa729b3ee 100644
--- a/pkg/executor/sortexec/parallel_sort_worker.go
+++ b/pkg/executor/sortexec/parallel_sort_worker.go
@@ -20,9 +20,11 @@ import (
 	"sync"
 	"time"
 
+	"github.com/pingcap/errors"
 	"github.com/pingcap/failpoint"
 	"github.com/pingcap/tidb/pkg/util/chunk"
 	"github.com/pingcap/tidb/pkg/util/memory"
+	"github.com/pingcap/tidb/pkg/util/sqlkiller"
 )
 
 // SignalCheckpointForSort indicates the times of row comparation that a signal detection will be triggered.
@@ -50,6 +52,8 @@ type parallelSortWorker struct {
 	chunkIters         []*chunk.Iterator4Chunk
 	rowNumInChunkIters int
 	merger             *multiWayMerger
+
+	sqlKiller *sqlkiller.SQLKiller
 }
 
 func newParallelSortWorker(
@@ -62,7 +66,9 @@ func newParallelSortWorker(
 	memTracker *memory.Tracker,
 	sortedRowsIter *chunk.Iterator4Slice,
 	maxChunkSize int,
-	spillHelper *parallelSortSpillHelper) *parallelSortWorker {
+	spillHelper *parallelSortSpillHelper,
+	sqlKiller *sqlkiller.SQLKiller,
+) *parallelSortWorker {
 	return &parallelSortWorker{
 		workerIDForTest:    workerIDForTest,
 		lessRowFunc:        lessRowFunc,
@@ -75,6 +81,7 @@ func newParallelSortWorker(
 		sortedRowsIter:     sortedRowsIter,
 		maxSortedRowsLimit: maxChunkSize * 30,
 		spillHelper:        spillHelper,
+		sqlKiller:          sqlKiller,
 	}
 }
 
@@ -112,7 +119,32 @@ func (p *parallelSortWorker) multiWayMergeLocalSortedRows() ([]chunk.Row, error)
 		return nil, err
 	}
 
+	loopCnt := uint64(0)
+
 	for {
+		var err error
+		failpoint.Inject("ParallelSortRandomFail", func(val failpoint.Value) {
+			if val.(bool) {
+				randNum := rand.Int31n(10000)
+				if randNum < 2 {
+					err = errors.NewNoStackErrorf("failpoint error")
+				}
+			}
+		})
+
+		if err != nil {
+			return nil, err
+		}
+
+		if loopCnt%100 == 0 && p.sqlKiller != nil {
+			err := p.sqlKiller.HandleSignal()
+			if err != nil {
+				return nil, err
+			}
+		}
+
+		loopCnt++
+
 		// It's impossible to return error here as rows are in memory
 		row, _ := p.merger.next()
 		if row.IsEmpty() {
@@ -202,9 +234,12 @@ func (p *parallelSortWorker) fetchChunksAndSortImpl() bool {
 }
 
 func (p *parallelSortWorker) keyColumnsLess(i, j chunk.Row) int {
-	if p.timesOfRowCompare >= SignalCheckpointForSort {
-		// Trigger Consume for checking the NeedKill signal
-		p.memTracker.Consume(1)
+	if p.timesOfRowCompare >= SignalCheckpointForSort && p.sqlKiller != nil {
+		err := p.sqlKiller.HandleSignal()
+		if err != nil {
+			panic(err)
+		}
+
 		p.timesOfRowCompare = 0
 	}
 
diff --git a/pkg/executor/sortexec/sort.go b/pkg/executor/sortexec/sort.go
index 5ce67e68648c7..edfb02c6ed2c3 100644
--- a/pkg/executor/sortexec/sort.go
+++ b/pkg/executor/sortexec/sort.go
@@ -667,7 +667,7 @@ func (e *SortExec) fetchChunksParallel(ctx context.Context) error {
 	fetcherWaiter := util.WaitGroupWrapper{}
 
 	for i := range e.Parallel.workers {
-		e.Parallel.workers[i] = newParallelSortWorker(i, e.lessRow, e.Parallel.chunkChannel, e.Parallel.fetcherAndWorkerSyncer, e.Parallel.resultChannel, e.finishCh, e.memTracker, e.Parallel.sortedRowsIters[i], e.MaxChunkSize(), e.Parallel.spillHelper)
+		e.Parallel.workers[i] = newParallelSortWorker(i, e.lessRow, e.Parallel.chunkChannel, e.Parallel.fetcherAndWorkerSyncer, e.Parallel.resultChannel, e.finishCh, e.memTracker, e.Parallel.sortedRowsIters[i], e.MaxChunkSize(), e.Parallel.spillHelper, &e.Ctx().GetSessionVars().SQLKiller)
 		worker := e.Parallel.workers[i]
 		workersWaiter.Run(func() {
 			worker.run()
diff --git a/pkg/executor/sortexec/sort_spill.go b/pkg/executor/sortexec/sort_spill.go
index 925e4f48df1c7..167d24e69e339 100644
--- a/pkg/executor/sortexec/sort_spill.go
+++ b/pkg/executor/sortexec/sort_spill.go
@@ -115,10 +115,6 @@ func (s *parallelSortSpillAction) actionImpl(t *memory.Tracker) bool {
 	}
 
 	if t.CheckExceed() && s.spillHelper.isNotSpilledNoLock() && hasEnoughDataToSpill(s.spillHelper.sortExec.memTracker, t) {
-		// Ideally, all goroutines entering this action should wait for the finish of spill once
-		// spill is triggered(we consider spill is triggered when the `needSpill` has been set).
-		// However, out of some reasons, we have to directly return before the finish of
-		// sort operation executed in spill as sort will retrigger the action and lead to dead lock.
 		s.spillHelper.setNeedSpillNoLock()
 		s.spillHelper.bytesConsumed.Store(t.BytesConsumed())
 		s.spillHelper.bytesLimit.Store(t.GetBytesLimit())
diff --git a/pkg/executor/sortexec/topn.go b/pkg/executor/sortexec/topn.go
index d302fd03d4d3f..68e6fe983e973 100644
--- a/pkg/executor/sortexec/topn.go
+++ b/pkg/executor/sortexec/topn.go
@@ -109,11 +109,12 @@ func (e *TopNExec) Open(ctx context.Context) error {
 			exec.RetTypes(e),
 			workers,
 			e.Concurrency,
+			&e.Ctx().GetSessionVars().SQLKiller,
 		)
 		e.spillAction = &topNSpillAction{spillHelper: e.spillHelper}
 		e.Ctx().GetSessionVars().MemTracker.FallbackOldAndSetNewAction(e.spillAction)
 	} else {
-		e.spillHelper = newTopNSpillerHelper(e, nil, nil, nil, nil, nil, nil, 0)
+		e.spillHelper = newTopNSpillerHelper(e, nil, nil, nil, nil, nil, nil, 0, nil)
 	}
 
 	return exec.Open(ctx, e.Children(0))
diff --git a/pkg/executor/sortexec/topn_chunk_heap.go b/pkg/executor/sortexec/topn_chunk_heap.go
index df19763e4693a..4d791b77372ac 100644
--- a/pkg/executor/sortexec/topn_chunk_heap.go
+++ b/pkg/executor/sortexec/topn_chunk_heap.go
@@ -16,11 +16,16 @@ package sortexec
 
 import (
 	"container/heap"
+	"context"
+	"testing"
 
 	"github.com/pingcap/tidb/pkg/executor/internal/exec"
 	"github.com/pingcap/tidb/pkg/types"
 	"github.com/pingcap/tidb/pkg/util/chunk"
+	"github.com/pingcap/tidb/pkg/util/dbterror/exeerrors"
 	"github.com/pingcap/tidb/pkg/util/memory"
+	"github.com/pingcap/tidb/pkg/util/sqlkiller"
+	"github.com/stretchr/testify/require"
 )
 
 // topNChunkHeap implements heap.Interface.
@@ -153,3 +158,22 @@ func (h *topNChunkHeap) Pop() any {
 func (h *topNChunkHeap) Swap(i, j int) {
 	h.rowPtrs[i], h.rowPtrs[j] = h.rowPtrs[j], h.rowPtrs[i]
 }
+
+// TestKillSignalInTopN is for test
+func TestKillSignalInTopN(t *testing.T, topnExec *TopNExec) {
+	ctx := context.Background()
+	err := topnExec.Open(ctx)
+	require.NoError(t, err)
+
+	chkHeap := &topNChunkHeap{}
+	// Offset of heap in worker should be 0, as we need to spill all data
+	chkHeap.init(topnExec, topnExec.memTracker, topnExec.Limit.Offset+topnExec.Limit.Count, 0, topnExec.greaterRow, topnExec.RetFieldTypes())
+	srcChk := exec.TryNewCacheChunk(topnExec.Children(0))
+	err = exec.Next(ctx, topnExec.Children(0), srcChk)
+	require.NoError(t, err)
+	chkHeap.rowChunks.Add(srcChk)
+
+	topnExec.Ctx().GetSessionVars().SQLKiller.SendKillSignal(sqlkiller.QueryInterrupted)
+	err = topnExec.spillHelper.spillHeap(chkHeap)
+	require.Error(t, err, exeerrors.ErrQueryInterrupted.GenWithStackByArgs())
+}
diff --git a/pkg/executor/sortexec/topn_spill.go b/pkg/executor/sortexec/topn_spill.go
index 4c3aa27957be6..d72d7bcec1342 100644
--- a/pkg/executor/sortexec/topn_spill.go
+++ b/pkg/executor/sortexec/topn_spill.go
@@ -26,6 +26,7 @@ import (
 	"github.com/pingcap/tidb/pkg/util/disk"
 	"github.com/pingcap/tidb/pkg/util/logutil"
 	"github.com/pingcap/tidb/pkg/util/memory"
+	"github.com/pingcap/tidb/pkg/util/sqlkiller"
 	"go.uber.org/zap"
 )
 
@@ -47,6 +48,8 @@ type topNSpillHelper struct {
 
 	bytesConsumed atomic.Int64
 	bytesLimit    atomic.Int64
+
+	sqlKiller *sqlkiller.SQLKiller
 }
 
 func newTopNSpillerHelper(
@@ -58,6 +61,7 @@ func newTopNSpillerHelper(
 	fieldTypes []*types.FieldType,
 	workers []*topNWorker,
 	concurrencyNum int,
+	sqlKiller *sqlkiller.SQLKiller,
 ) *topNSpillHelper {
 	lock := sync.Mutex{}
 	tmpSpillChunksChan := make(chan *chunk.Chunk, concurrencyNum)
@@ -78,6 +82,7 @@ func newTopNSpillerHelper(
 		workers:       workers,
 		bytesConsumed: atomic.Int64{},
 		bytesLimit:    atomic.Int64{},
+		sqlKiller:     sqlKiller,
 	}
 }
 
@@ -209,6 +214,13 @@ func (t *topNSpillHelper) spillHeap(chkHeap *topNChunkHeap) error {
 
 	rowPtrNum := chkHeap.Len()
 	for ; chkHeap.idx < rowPtrNum; chkHeap.idx++ {
+		if chkHeap.idx%100 == 0 && t.sqlKiller != nil {
+			err := t.sqlKiller.HandleSignal()
+			if err != nil {
+				return err
+			}
+		}
+
 		if tmpSpillChunk.IsFull() {
 			err := t.spillTmpSpillChunk(inDisk, tmpSpillChunk)
 			if err != nil {
diff --git a/pkg/executor/sortexec/topn_spill_test.go b/pkg/executor/sortexec/topn_spill_test.go
index 159e33e7c0876..3e6378b485ae5 100644
--- a/pkg/executor/sortexec/topn_spill_test.go
+++ b/pkg/executor/sortexec/topn_spill_test.go
@@ -493,6 +493,28 @@ func TestIssue54206(t *testing.T) {
 	tk.MustQuery("select t1.a+t1.b as result from t1 left join t2 on 1 = 0 order by result limit 1;")
 }
 
+func TestIssue54541(t *testing.T) {
+	totalRowNum := 30
+	sortexec.SetSmallSpillChunkSizeForTest()
+	ctx := mock.NewContext()
+	topNCase := &testutil.SortCase{Rows: totalRowNum, OrderByIdx: []int{0, 1}, Ndvs: []int{0, 0}, Ctx: ctx}
+
+	ctx.GetSessionVars().InitChunkSize = 32
+	ctx.GetSessionVars().MaxChunkSize = 32
+	ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, hardLimit2)
+	ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
+	ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
+
+	offset := uint64(totalRowNum / 10)
+	count := uint64(totalRowNum / 3)
+
+	schema := expression.NewSchema(topNCase.Columns()...)
+	dataSource := buildDataSource(topNCase, schema)
+	exe := buildTopNExec(topNCase, dataSource, offset, count)
+
+	sortexec.TestKillSignalInTopN(t, exe)
+}
+
 func TestTopNFallBackAction(t *testing.T) {
 	sortexec.SetSmallSpillChunkSizeForTest()
 	ctx := mock.NewContext()