executor: check kill signal for topn and parallel sort spill (#56238)
close #54541
xzhangxian1008 authored Oct 8, 2024
1 parent 0ab49be commit 179521a
Showing 9 changed files with 109 additions and 16 deletions.
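The whole change applies one pattern: thread a `*sqlkiller.SQLKiller` through the sort/TopN spill code and poll it periodically inside hot loops, so a KILL takes effect even while the executor is merging or spilling rows. A minimal sketch of that polling shape as assumed here — `killer` and `drainRows` are hypothetical names; only `HandleSignal` mirrors the real method:

```go
package main

import (
	"errors"
	"fmt"
)

// killer is a hypothetical stand-in for sqlkiller.SQLKiller: HandleSignal
// returns a non-nil error once a kill signal has been delivered.
type killer struct{ signaled bool }

func (k *killer) HandleSignal() error {
	if k.signaled {
		return errors.New("query interrupted")
	}
	return nil
}

// drainRows polls the killer once every 100 iterations, mirroring the
// loopCnt%100 == 0 checks added below, so a long-running loop stays
// responsive to KILL without paying the check on every row.
func drainRows(rows []int, k *killer) error {
	for i := range rows {
		if i%100 == 0 && k != nil { // nil killer: checks are disabled
			if err := k.HandleSignal(); err != nil {
				return err
			}
		}
		_ = rows[i] // stands in for the real per-row work
	}
	return nil
}

func main() {
	fmt.Println(drainRows(make([]int, 500), &killer{signaled: true})) // query interrupted
}
```

The fixed interval trades detection latency for per-row overhead; the diffs below use the same 100-iteration throttle.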
4 changes: 3 additions & 1 deletion pkg/executor/sortexec/BUILD.bazel
@@ -27,12 +27,14 @@ go_library(
"//pkg/util",
"//pkg/util/channel",
"//pkg/util/chunk",
"//pkg/util/dbterror/exeerrors",
"//pkg/util/disk",
"//pkg/util/logutil",
"//pkg/util/memory",
"//pkg/util/sqlkiller",
"@com_github_pingcap_errors//:errors",
"@com_github_pingcap_failpoint//:failpoint",
"@com_github_stretchr_testify//require",
"@org_uber_go_zap//:zap",
],
)
@@ -42,7 +44,7 @@ go_test(
timeout = "short",
srcs = ["sort_test.go"],
flaky = True,
shard_count = 16,
shard_count = 17,
deps = [
"//pkg/config",
"//pkg/sessionctx/variable",
11 changes: 6 additions & 5 deletions pkg/executor/sortexec/parallel_sort_spill_helper.go
@@ -108,6 +108,12 @@ func (p *parallelSortSpillHelper) spill() (err error) {
}
}()

p.setInSpilling()

// Spill is done, broadcast to wake up all sleeping goroutines
defer p.cond.Broadcast()
defer p.setNotSpilled()

select {
case <-p.finishCh:
return nil
@@ -138,11 +144,6 @@ func (p *parallelSortSpillHelper) spill() (err error) {
}

workerWaiter.Wait()
p.setInSpilling()

// Spill is done, broadcast to wake up all sleeping goroutines
defer p.cond.Broadcast()
defer p.setNotSpilled()

totalRows := 0
for i := range sortedRowsIters {
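The two hunks above move `p.setInSpilling()` and the deferred `cond.Broadcast()`/`setNotSpilled()` from after `workerWaiter.Wait()` to the top of `spill()`. With the flag set and the defers registered up front, every exit path — the early `return nil` on `finishCh` as much as any error return — clears the spilling state and wakes goroutines sleeping on the condition variable, rather than leaving them blocked when `spill()` bails out early. A minimal sketch of that defer-based shape, with illustrative names rather than TiDB's API:

```go
package main

import "sync"

// spillCoordinator mirrors the helper's state handling: a flag guarded by
// a mutex-backed condition variable.
type spillCoordinator struct {
	mu       sync.Mutex
	cond     *sync.Cond
	spilling bool
}

func newSpillCoordinator() *spillCoordinator {
	c := &spillCoordinator{}
	c.cond = sync.NewCond(&c.mu)
	return c
}

func (c *spillCoordinator) setSpilling(v bool) {
	c.mu.Lock()
	c.spilling = v
	c.mu.Unlock()
}

// spill sets the flag first and registers the reset and wake-up as defers,
// so every return path releases the waiters.
func (c *spillCoordinator) spill(work func() error) error {
	c.setSpilling(true)
	defer c.cond.Broadcast()   // runs last (LIFO): wake sleepers after the flag is cleared
	defer c.setSpilling(false) // runs first: clear the flag on every exit path
	return work()
}

// waitWhileSpilling is what a data-producing goroutine would call; it
// sleeps until the spill goroutine broadcasts completion.
func (c *spillCoordinator) waitWhileSpilling() {
	c.mu.Lock()
	for c.spilling {
		c.cond.Wait()
	}
	c.mu.Unlock()
}

func main() {
	c := newSpillCoordinator()
	done := make(chan struct{})
	go func() {
		c.waitWhileSpilling() // returns once spill broadcasts completion
		close(done)
	}()
	_ = c.spill(func() error { return nil }) // any return path wakes the waiter
	<-done
}
```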
43 changes: 39 additions & 4 deletions pkg/executor/sortexec/parallel_sort_worker.go
@@ -20,9 +20,11 @@ import (
"sync"
"time"

"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/memory"
"github.com/pingcap/tidb/pkg/util/sqlkiller"
)

// SignalCheckpointForSort indicates the number of row comparisons after which a signal check will be triggered.
@@ -50,6 +52,8 @@ type parallelSortWorker struct {
chunkIters []*chunk.Iterator4Chunk
rowNumInChunkIters int
merger *multiWayMerger

sqlKiller *sqlkiller.SQLKiller
}

func newParallelSortWorker(
@@ -62,7 +66,9 @@
memTracker *memory.Tracker,
sortedRowsIter *chunk.Iterator4Slice,
maxChunkSize int,
spillHelper *parallelSortSpillHelper) *parallelSortWorker {
spillHelper *parallelSortSpillHelper,
sqlKiller *sqlkiller.SQLKiller,
) *parallelSortWorker {
return &parallelSortWorker{
workerIDForTest: workerIDForTest,
lessRowFunc: lessRowFunc,
@@ -75,6 +81,7 @@
sortedRowsIter: sortedRowsIter,
maxSortedRowsLimit: maxChunkSize * 30,
spillHelper: spillHelper,
sqlKiller: sqlKiller,
}
}

@@ -112,7 +119,32 @@ func (p *parallelSortWorker) multiWayMergeLocalSortedRows() ([]chunk.Row, error)
return nil, err
}

loopCnt := uint64(0)

for {
var err error
failpoint.Inject("ParallelSortRandomFail", func(val failpoint.Value) {
if val.(bool) {
randNum := rand.Int31n(10000)
if randNum < 2 {
err = errors.NewNoStackErrorf("failpoint error")
}
}
})

if err != nil {
return nil, err
}

if loopCnt%100 == 0 && p.sqlKiller != nil {
err := p.sqlKiller.HandleSignal()
if err != nil {
return nil, err
}
}

loopCnt++

// It's impossible to return an error here, as rows are in memory
row, _ := p.merger.next()
if row.IsEmpty() {
@@ -202,9 +234,12 @@ func (p *parallelSortWorker) fetchChunksAndSortImpl() bool {
}

func (p *parallelSortWorker) keyColumnsLess(i, j chunk.Row) int {
if p.timesOfRowCompare >= SignalCheckpointForSort {
// Trigger Consume for checking the NeedKill signal
p.memTracker.Consume(1)
if p.timesOfRowCompare >= SignalCheckpointForSort && p.sqlKiller != nil {
err := p.sqlKiller.HandleSignal()
if err != nil {
panic(err)
}

p.timesOfRowCompare = 0
}

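Two kinds of checks are added to this worker. `multiWayMergeLocalSortedRows` polls `HandleSignal` every 100 merge iterations (plus a `ParallelSortRandomFail` failpoint for tests) and returns the error normally. `keyColumnsLess`, by contrast, used to call `p.memTracker.Consume(1)` purely to trigger the tracker's NeedKill detection; it now calls `HandleSignal` directly and must `panic` on error, since a sort comparator has no error return — presumably the panic is recovered by the worker's existing panic handling. A runnable sketch of that comparator shape under those assumptions; names are illustrative, and `checkpoint` plays the role of `SignalCheckpointForSort`:

```go
package main

import (
	"errors"
	"fmt"
	"sort"
)

// errInterrupted stands in for TiDB's "query interrupted" error.
var errInterrupted = errors.New("query interrupted")

// checkpoint is the number of comparisons between two kill-signal checks.
const checkpoint = 10000

// sortWithKill sorts rows while staying killable. The less func passed to
// sort.Slice cannot return an error, so an interrupt is raised as a panic
// and converted back into an ordinary error at the function boundary.
func sortWithKill(rows []int, handleSignal func() error) (err error) {
	defer func() {
		if r := recover(); r != nil {
			if e, ok := r.(error); ok && errors.Is(e, errInterrupted) {
				err = e // interrupt: surface as a normal error
				return
			}
			panic(r) // unrelated panic: re-raise
		}
	}()

	compares := 0
	sort.Slice(rows, func(i, j int) bool {
		compares++
		if compares >= checkpoint {
			if e := handleSignal(); e != nil {
				panic(e) // no error return available here
			}
			compares = 0
		}
		return rows[i] < rows[j]
	})
	return nil
}

func main() {
	rows := make([]int, 1<<16) // enough comparisons to cross the checkpoint
	err := sortWithKill(rows, func() error { return errInterrupted })
	fmt.Println(err) // query interrupted
}
```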
2 changes: 1 addition & 1 deletion pkg/executor/sortexec/sort.go
@@ -667,7 +667,7 @@ func (e *SortExec) fetchChunksParallel(ctx context.Context) error {
fetcherWaiter := util.WaitGroupWrapper{}

for i := range e.Parallel.workers {
e.Parallel.workers[i] = newParallelSortWorker(i, e.lessRow, e.Parallel.chunkChannel, e.Parallel.fetcherAndWorkerSyncer, e.Parallel.resultChannel, e.finishCh, e.memTracker, e.Parallel.sortedRowsIters[i], e.MaxChunkSize(), e.Parallel.spillHelper)
e.Parallel.workers[i] = newParallelSortWorker(i, e.lessRow, e.Parallel.chunkChannel, e.Parallel.fetcherAndWorkerSyncer, e.Parallel.resultChannel, e.finishCh, e.memTracker, e.Parallel.sortedRowsIters[i], e.MaxChunkSize(), e.Parallel.spillHelper, &e.Ctx().GetSessionVars().SQLKiller)
worker := e.Parallel.workers[i]
workersWaiter.Run(func() {
worker.run()
4 changes: 0 additions & 4 deletions pkg/executor/sortexec/sort_spill.go
@@ -115,10 +115,6 @@ func (s *parallelSortSpillAction) actionImpl(t *memory.Tracker) bool {
}

if t.CheckExceed() && s.spillHelper.isNotSpilledNoLock() && hasEnoughDataToSpill(s.spillHelper.sortExec.memTracker, t) {
// Ideally, all goroutines entering this action should wait for the finish of spill once
// spill is triggered(we consider spill is triggered when the `needSpill` has been set).
// However, out of some reasons, we have to directly return before the finish of
// sort operation executed in spill as sort will retrigger the action and lead to dead lock.
s.spillHelper.setNeedSpillNoLock()
s.spillHelper.bytesConsumed.Store(t.BytesConsumed())
s.spillHelper.bytesLimit.Store(t.GetBytesLimit())
3 changes: 2 additions & 1 deletion pkg/executor/sortexec/topn.go
@@ -109,11 +109,12 @@ func (e *TopNExec) Open(ctx context.Context) error {
exec.RetTypes(e),
workers,
e.Concurrency,
&e.Ctx().GetSessionVars().SQLKiller,
)
e.spillAction = &topNSpillAction{spillHelper: e.spillHelper}
e.Ctx().GetSessionVars().MemTracker.FallbackOldAndSetNewAction(e.spillAction)
} else {
e.spillHelper = newTopNSpillerHelper(e, nil, nil, nil, nil, nil, nil, 0)
e.spillHelper = newTopNSpillerHelper(e, nil, nil, nil, nil, nil, nil, 0, nil)
}

return exec.Open(ctx, e.Children(0))
24 changes: 24 additions & 0 deletions pkg/executor/sortexec/topn_chunk_heap.go
@@ -16,11 +16,16 @@ package sortexec

import (
"container/heap"
"context"
"testing"

"github.com/pingcap/tidb/pkg/executor/internal/exec"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/dbterror/exeerrors"
"github.com/pingcap/tidb/pkg/util/memory"
"github.com/pingcap/tidb/pkg/util/sqlkiller"
"github.com/stretchr/testify/require"
)

// topNChunkHeap implements heap.Interface.
@@ -153,3 +158,22 @@ func (h *topNChunkHeap) Pop() any {
func (h *topNChunkHeap) Swap(i, j int) {
h.rowPtrs[i], h.rowPtrs[j] = h.rowPtrs[j], h.rowPtrs[i]
}

// TestKillSignalInTopN is exported only for testing
func TestKillSignalInTopN(t *testing.T, topnExec *TopNExec) {
ctx := context.Background()
err := topnExec.Open(ctx)
require.NoError(t, err)

chkHeap := &topNChunkHeap{}
// Offset of heap in worker should be 0, as we need to spill all data
chkHeap.init(topnExec, topnExec.memTracker, topnExec.Limit.Offset+topnExec.Limit.Count, 0, topnExec.greaterRow, topnExec.RetFieldTypes())
srcChk := exec.TryNewCacheChunk(topnExec.Children(0))
err = exec.Next(ctx, topnExec.Children(0), srcChk)
require.NoError(t, err)
chkHeap.rowChunks.Add(srcChk)

topnExec.Ctx().GetSessionVars().SQLKiller.SendKillSignal(sqlkiller.QueryInterrupted)
err = topnExec.spillHelper.spillHeap(chkHeap)
require.Error(t, err, exeerrors.ErrQueryInterrupted.GenWithStackByArgs())
}
12 changes: 12 additions & 0 deletions pkg/executor/sortexec/topn_spill.go
@@ -26,6 +26,7 @@ import (
"github.com/pingcap/tidb/pkg/util/disk"
"github.com/pingcap/tidb/pkg/util/logutil"
"github.com/pingcap/tidb/pkg/util/memory"
"github.com/pingcap/tidb/pkg/util/sqlkiller"
"go.uber.org/zap"
)

@@ -47,6 +48,8 @@ type topNSpillHelper struct {

bytesConsumed atomic.Int64
bytesLimit atomic.Int64

sqlKiller *sqlkiller.SQLKiller
}

func newTopNSpillerHelper(
@@ -58,6 +61,7 @@
fieldTypes []*types.FieldType,
workers []*topNWorker,
concurrencyNum int,
sqlKiller *sqlkiller.SQLKiller,
) *topNSpillHelper {
lock := sync.Mutex{}
tmpSpillChunksChan := make(chan *chunk.Chunk, concurrencyNum)
@@ -78,6 +82,7 @@
workers: workers,
bytesConsumed: atomic.Int64{},
bytesLimit: atomic.Int64{},
sqlKiller: sqlKiller,
}
}

@@ -209,6 +214,13 @@ func (t *topNSpillHelper) spillHeap(chkHeap *topNChunkHeap) error {

rowPtrNum := chkHeap.Len()
for ; chkHeap.idx < rowPtrNum; chkHeap.idx++ {
if chkHeap.idx%100 == 0 && t.sqlKiller != nil {
err := t.sqlKiller.HandleSignal()
if err != nil {
return err
}
}

if tmpSpillChunk.IsFull() {
err := t.spillTmpSpillChunk(inDisk, tmpSpillChunk)
if err != nil {
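`spillHeap` gets the same interval-based throttle, guarded with `t.sqlKiller != nil`. The nil guard matters: the fallback branch in the `topn.go` hunk above builds the helper as `newTopNSpillerHelper(e, nil, ..., 0, nil)`, so when spilling is disabled the check must degrade to a no-op. A tiny nil-tolerant helper in that spirit — hypothetical names, with only the `HandleSignal` error contract taken from the diff:

```go
package main

import (
	"errors"
	"fmt"
)

// handleSignal mimics the shape of (*sqlkiller.SQLKiller).HandleSignal.
type handleSignal func() error

// pollEvery checks the killer only on every interval-th iteration and
// tolerates a nil killer, matching the idx%100 == 0 && killer != nil guard.
func pollEvery(check handleSignal, i, interval int) error {
	if check == nil || i%interval != 0 {
		return nil
	}
	return check()
}

func main() {
	kill := func() error { return errors.New("query interrupted") }
	fmt.Println(pollEvery(kill, 0, 100))  // query interrupted (0%100 == 0)
	fmt.Println(pollEvery(kill, 57, 100)) // <nil>: off-interval, skipped
	fmt.Println(pollEvery(nil, 0, 100))   // <nil>: nil killer disables checks
}
```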
22 changes: 22 additions & 0 deletions pkg/executor/sortexec/topn_spill_test.go
@@ -493,6 +493,28 @@ func TestIssue54206(t *testing.T) {
tk.MustQuery("select t1.a+t1.b as result from t1 left join t2 on 1 = 0 order by result limit 1;")
}

func TestIssue54541(t *testing.T) {
totalRowNum := 30
sortexec.SetSmallSpillChunkSizeForTest()
ctx := mock.NewContext()
topNCase := &testutil.SortCase{Rows: totalRowNum, OrderByIdx: []int{0, 1}, Ndvs: []int{0, 0}, Ctx: ctx}

ctx.GetSessionVars().InitChunkSize = 32
ctx.GetSessionVars().MaxChunkSize = 32
ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, hardLimit2)
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)

offset := uint64(totalRowNum / 10)
count := uint64(totalRowNum / 3)

schema := expression.NewSchema(topNCase.Columns()...)
dataSource := buildDataSource(topNCase, schema)
exe := buildTopNExec(topNCase, dataSource, offset, count)

sortexec.TestKillSignalInTopN(t, exe)
}

func TestTopNFallBackAction(t *testing.T) {
sortexec.SetSmallSpillChunkSizeForTest()
ctx := mock.NewContext()
