From 40c4a1675309a2b240df6c1d6ebcc0fb40409446 Mon Sep 17 00:00:00 2001 From: tangenta Date: Tue, 29 Aug 2023 14:59:15 +0800 Subject: [PATCH] ddl, disttask: implement add index operators (#46414) ref pingcap/tidb#46258 --- ddl/BUILD.bazel | 2 + ddl/backfilling_operator.go | 514 ++++++++++++++++++ ddl/backfilling_scheduler.go | 39 +- ddl/export_test.go | 10 +- ddl/index.go | 4 +- ddl/index_cop.go | 28 +- ddl/ingest/mock.go | 25 +- disttask/operator/compose.go | 2 + disttask/operator/operator.go | 21 +- disttask/operator/wrapper.go | 3 +- executor/executor.go | 7 +- resourcemanager/pool/workerpool/BUILD.bazel | 2 +- resourcemanager/pool/workerpool/workerpool.go | 37 +- .../pool/workerpool/workpool_test.go | 44 +- tests/realtikvtest/addindextest/BUILD.bazel | 9 + .../addindextest/operator_test.go | 319 +++++++++++ 16 files changed, 992 insertions(+), 74 deletions(-) create mode 100644 ddl/backfilling_operator.go create mode 100644 tests/realtikvtest/addindextest/operator_test.go diff --git a/ddl/BUILD.bazel b/ddl/BUILD.bazel index fc4c4bf7b201c..b5af91de6640c 100644 --- a/ddl/BUILD.bazel +++ b/ddl/BUILD.bazel @@ -12,6 +12,7 @@ go_library( name = "ddl", srcs = [ "backfilling.go", + "backfilling_operator.go", "backfilling_scheduler.go", "callback.go", "cluster.go", @@ -73,6 +74,7 @@ go_library( "//disttask/framework/handle", "//disttask/framework/proto", "//disttask/framework/scheduler", + "//disttask/operator", "//domain/infosync", "//domain/resourcegroup", "//expression", diff --git a/ddl/backfilling_operator.go b/ddl/backfilling_operator.go new file mode 100644 index 0000000000000..e0a4059462aa2 --- /dev/null +++ b/ddl/backfilling_operator.go @@ -0,0 +1,514 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ddl + +import ( + "context" + "encoding/hex" + "fmt" + "sync/atomic" + "time" + + "github.com/pingcap/tidb/ddl/ingest" + "github.com/pingcap/tidb/ddl/internal/session" + "github.com/pingcap/tidb/disttask/operator" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/parser/terror" + "github.com/pingcap/tidb/resourcemanager/pool/workerpool" + "github.com/pingcap/tidb/resourcemanager/util" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/table/tables" + "github.com/pingcap/tidb/tablecodec" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/logutil" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" +) + +var ( + _ operator.Operator = (*TableScanTaskSource)(nil) + _ operator.WithSink[TableScanTask] = (*TableScanTaskSource)(nil) + + _ operator.WithSource[TableScanTask] = (*TableScanOperator)(nil) + _ operator.Operator = (*TableScanOperator)(nil) + _ operator.WithSink[IndexRecordChunk] = (*TableScanOperator)(nil) + + _ operator.WithSource[IndexRecordChunk] = (*IndexIngestOperator)(nil) + _ operator.Operator = (*IndexIngestOperator)(nil) + _ operator.WithSink[IndexWriteResult] = (*IndexIngestOperator)(nil) + + _ operator.WithSource[IndexWriteResult] = (*indexWriteResultSink)(nil) + _ operator.Operator = (*indexWriteResultSink)(nil) +) + +type opSessPool interface { + Get() (sessionctx.Context, error) + Put(sessionctx.Context) +} + +// NewAddIndexIngestPipeline creates a pipeline for adding index in ingest mode. +// TODO(tangenta): add failpoint tests for these operators to ensure the robustness. +func NewAddIndexIngestPipeline( + ctx context.Context, + store kv.Storage, + sessPool opSessPool, + engine ingest.Engine, + sessCtx sessionctx.Context, + tbl table.PhysicalTable, + idxInfo *model.IndexInfo, + startKey, endKey kv.Key, +) (*operator.AsyncPipeline, error) { + index := tables.NewIndex(tbl.GetPhysicalID(), tbl.Meta(), idxInfo) + copCtx, err := NewCopContext(tbl.Meta(), idxInfo, sessCtx) + if err != nil { + return nil, err + } + poolSize := copReadChunkPoolSize() + srcChkPool := make(chan *chunk.Chunk, poolSize) + for i := 0; i < poolSize; i++ { + srcChkPool <- chunk.NewChunkWithCapacity(copCtx.fieldTps, copReadBatchSize()) + } + readerCnt, writerCnt := expectedIngestWorkerCnt() + + srcOp := NewTableScanTaskSource(ctx, store, tbl, startKey, endKey) + scanOp := NewTableScanOperator(ctx, sessPool, copCtx, srcChkPool, readerCnt) + ingestOp := NewIndexIngestOperator(ctx, copCtx, sessPool, tbl, index, engine, srcChkPool, writerCnt) + sinkOp := newIndexWriteResultSink(ctx) + + operator.Compose[TableScanTask](srcOp, scanOp) + operator.Compose[IndexRecordChunk](scanOp, ingestOp) + operator.Compose[IndexWriteResult](ingestOp, sinkOp) + + return operator.NewAsyncPipeline( + srcOp, scanOp, ingestOp, sinkOp, + ), nil +} + +// TableScanTask contains the start key and the end key of a region. +type TableScanTask struct { + ID int + Start kv.Key + End kv.Key +} + +// String implement fmt.Stringer interface. +func (t *TableScanTask) String() string { + return fmt.Sprintf("TableScanTask: id=%d, startKey=%s, endKey=%s", + t.ID, hex.EncodeToString(t.Start), hex.EncodeToString(t.End)) +} + +// IndexRecordChunk contains one of the chunk read from corresponding TableScanTask. +type IndexRecordChunk struct { + ID int + Chunk *chunk.Chunk + Err error + Done bool +} + +// TableScanTaskSource produces TableScanTask by splitting table records into ranges. 
+type TableScanTaskSource struct { + ctx context.Context + + errGroup errgroup.Group + sink operator.DataChannel[TableScanTask] + + tbl table.PhysicalTable + store kv.Storage + startKey kv.Key + endKey kv.Key +} + +// NewTableScanTaskSource creates a new TableScanTaskSource. +func NewTableScanTaskSource( + ctx context.Context, + store kv.Storage, + physicalTable table.PhysicalTable, + startKey kv.Key, + endKey kv.Key, +) *TableScanTaskSource { + return &TableScanTaskSource{ + ctx: ctx, + errGroup: errgroup.Group{}, + tbl: physicalTable, + store: store, + startKey: startKey, + endKey: endKey, + } +} + +// SetSink implements WithSink interface. +func (src *TableScanTaskSource) SetSink(sink operator.DataChannel[TableScanTask]) { + src.sink = sink +} + +// Open implements Operator interface. +func (src *TableScanTaskSource) Open() error { + src.errGroup.Go(src.generateTasks) + return nil +} + +func (src *TableScanTaskSource) generateTasks() error { + taskIDAlloc := newTaskIDAllocator() + defer src.sink.Finish() + startKey := src.startKey + endKey := src.endKey + for { + kvRanges, err := splitTableRanges( + src.tbl, + src.store, + startKey, + endKey, + backfillTaskChanSize, + ) + if err != nil { + return err + } + if len(kvRanges) == 0 { + break + } + + batchTasks := src.getBatchTableScanTask(kvRanges, taskIDAlloc) + for _, task := range batchTasks { + select { + case <-src.ctx.Done(): + return src.ctx.Err() + case src.sink.Channel() <- task: + } + } + startKey = kvRanges[len(kvRanges)-1].EndKey + if startKey.Cmp(endKey) >= 0 { + break + } + } + return nil +} + +func (src *TableScanTaskSource) getBatchTableScanTask( + kvRanges []kv.KeyRange, + taskIDAlloc *taskIDAllocator, +) []TableScanTask { + batchTasks := make([]TableScanTask, 0, len(kvRanges)) + prefix := src.tbl.RecordPrefix() + // Build reorg tasks. + for _, keyRange := range kvRanges { + taskID := taskIDAlloc.alloc() + startKey := keyRange.StartKey + if len(startKey) == 0 { + startKey = prefix + } + endKey := keyRange.EndKey + if len(endKey) == 0 { + endKey = prefix.PrefixNext() + } + + task := TableScanTask{ + ID: taskID, + Start: startKey, + End: endKey, + } + batchTasks = append(batchTasks, task) + } + return batchTasks +} + +// Close implements Operator interface. +func (src *TableScanTaskSource) Close() error { + return src.errGroup.Wait() +} + +// String implements fmt.Stringer interface. +func (*TableScanTaskSource) String() string { + return "TableScanTaskSource" +} + +// TableScanOperator scans table records in given key ranges from kv store. +type TableScanOperator struct { + *operator.AsyncOperator[TableScanTask, IndexRecordChunk] +} + +// NewTableScanOperator creates a new TableScanOperator. 
+func NewTableScanOperator( + ctx context.Context, + sessPool opSessPool, + copCtx *copContext, + srcChkPool chan *chunk.Chunk, + concurrency int, +) *TableScanOperator { + pool := workerpool.NewWorkerPool( + "TableScanOperator", + util.DDL, + concurrency, + func() workerpool.Worker[TableScanTask, IndexRecordChunk] { + return &tableScanWorker{ + ctx: ctx, + copCtx: copCtx, + sessPool: sessPool, + se: nil, + srcChkPool: srcChkPool, + } + }) + return &TableScanOperator{ + AsyncOperator: operator.NewAsyncOperator[TableScanTask, IndexRecordChunk](ctx, pool), + } +} + +type tableScanWorker struct { + ctx context.Context + copCtx *copContext + sessPool opSessPool + se *session.Session + srcChkPool chan *chunk.Chunk +} + +func (w *tableScanWorker) HandleTask(task TableScanTask, sender func(IndexRecordChunk)) { + if w.se == nil { + sessCtx, err := w.sessPool.Get() + if err != nil { + logutil.Logger(w.ctx).Error("tableScanWorker get session from pool failed", zap.Error(err)) + sender(IndexRecordChunk{Err: err}) + return + } + w.se = session.NewSession(sessCtx) + } + w.scanRecords(task, sender) +} + +func (w *tableScanWorker) Close() { + w.sessPool.Put(w.se.Context) +} + +func (w *tableScanWorker) scanRecords(task TableScanTask, sender func(IndexRecordChunk)) { + logutil.Logger(w.ctx).Info("start a table scan task", + zap.Int("id", task.ID), zap.String("task", task.String())) + + var idxResult IndexRecordChunk + err := wrapInBeginRollback(w.se, func(startTS uint64) error { + rs, err := w.copCtx.buildTableScan(w.ctx, startTS, task.Start, task.End) + if err != nil { + return err + } + var done bool + for !done { + srcChk := w.getChunk() + done, err = w.copCtx.fetchTableScanResult(w.ctx, rs, srcChk) + if err != nil { + w.recycleChunk(srcChk) + terror.Call(rs.Close) + return err + } + idxResult = IndexRecordChunk{ID: task.ID, Chunk: srcChk, Done: done} + sender(idxResult) + } + return rs.Close() + }) + if err != nil { + // TODO(tangenta): cancel operator instead of sending error to sink. + idxResult.Err = err + sender(idxResult) + } +} + +func (w *tableScanWorker) getChunk() *chunk.Chunk { + chk := <-w.srcChkPool + newCap := copReadBatchSize() + if chk.Capacity() != newCap { + chk = chunk.NewChunkWithCapacity(w.copCtx.fieldTps, newCap) + } + chk.Reset() + return chk +} + +func (w *tableScanWorker) recycleChunk(chk *chunk.Chunk) { + w.srcChkPool <- chk +} + +// IndexWriteResult contains the result of writing index records to ingest engine. +type IndexWriteResult struct { + ID int + Added int + Total int + Next kv.Key + Err error +} + +// IndexIngestOperator writes index records to ingest engine. +type IndexIngestOperator struct { + *operator.AsyncOperator[IndexRecordChunk, IndexWriteResult] +} + +// NewIndexIngestOperator creates a new IndexIngestOperator. 
+func NewIndexIngestOperator( + ctx context.Context, + copCtx *copContext, + sessPool opSessPool, + tbl table.PhysicalTable, + index table.Index, + engine ingest.Engine, + srcChunkPool chan *chunk.Chunk, + concurrency int, +) *IndexIngestOperator { + var writerIDAlloc atomic.Int32 + pool := workerpool.NewWorkerPool( + "indexIngestOperator", + util.DDL, + concurrency, + func() workerpool.Worker[IndexRecordChunk, IndexWriteResult] { + writerID := int(writerIDAlloc.Add(1)) + writer, err := engine.CreateWriter(writerID, index.Meta().Unique) + if err != nil { + logutil.Logger(ctx).Error("create index ingest worker failed", zap.Error(err)) + return nil + } + return &indexIngestWorker{ + ctx: ctx, + tbl: tbl, + index: index, + copCtx: copCtx, + se: nil, + sessPool: sessPool, + writer: writer, + srcChunkPool: srcChunkPool, + } + }) + return &IndexIngestOperator{ + AsyncOperator: operator.NewAsyncOperator[IndexRecordChunk, IndexWriteResult](ctx, pool), + } +} + +type indexIngestWorker struct { + ctx context.Context + + tbl table.PhysicalTable + index table.Index + + copCtx *copContext + sessPool opSessPool + se *session.Session + + writer ingest.Writer + srcChunkPool chan *chunk.Chunk +} + +func (w *indexIngestWorker) HandleTask(rs IndexRecordChunk, send func(IndexWriteResult)) { + defer func() { + if rs.Chunk != nil { + w.srcChunkPool <- rs.Chunk + } + }() + result := IndexWriteResult{ + ID: rs.ID, + Err: rs.Err, + } + if w.se == nil { + sessCtx, err := w.sessPool.Get() + if err != nil { + result.Err = err + send(result) + return + } + w.se = session.NewSession(sessCtx) + } + if result.Err != nil { + logutil.Logger(w.ctx).Error("encounter error when handle index chunk", + zap.Int("id", rs.ID), zap.Error(rs.Err)) + send(result) + return + } + count, nextKey, err := w.WriteLocal(&rs) + if err != nil { + result.Err = err + send(result) + return + } + if count == 0 { + logutil.Logger(w.ctx).Info("finish a index ingest task", zap.Int("id", rs.ID)) + send(result) + return + } + result.Added = count + result.Next = nextKey + if ResultCounterForTest != nil && result.Err == nil { + ResultCounterForTest.Add(1) + } + send(result) +} + +func (*indexIngestWorker) Close() { +} + +// WriteLocal will write index records to lightning engine. +func (w *indexIngestWorker) WriteLocal(rs *IndexRecordChunk) (count int, nextKey kv.Key, err error) { + oprStartTime := time.Now() + vars := w.se.GetSessionVars() + cnt, lastHandle, err := writeChunkToLocal(w.writer, w.index, w.copCtx, vars, rs.Chunk) + if err != nil || cnt == 0 { + return 0, nil, err + } + logSlowOperations(time.Since(oprStartTime), "writeChunkToLocal", 3000) + nextKey = tablecodec.EncodeRecordKey(w.tbl.RecordPrefix(), lastHandle) + return cnt, nextKey, nil +} + +type indexWriteResultSink struct { + ctx context.Context + + errGroup errgroup.Group + source operator.DataChannel[IndexWriteResult] +} + +func newIndexWriteResultSink( + ctx context.Context, +) *indexWriteResultSink { + return &indexWriteResultSink{ + ctx: ctx, + errGroup: errgroup.Group{}, + } +} + +func (s *indexWriteResultSink) SetSource(source operator.DataChannel[IndexWriteResult]) { + s.source = source +} + +func (s *indexWriteResultSink) Open() error { + s.errGroup.Go(s.collectResult) + return nil +} + +func (s *indexWriteResultSink) collectResult() error { + // TODO(tangenta): use results to update reorg info and metrics. 
+ for { + select { + case <-s.ctx.Done(): + return nil + case result, ok := <-s.source.Channel(): + if !ok { + return nil + } + if result.Err != nil { + return result.Err + } + } + } +} + +func (s *indexWriteResultSink) Close() error { + return s.errGroup.Wait() +} + +func (*indexWriteResultSink) String() string { + return "indexWriteResultSink" +} diff --git a/ddl/backfilling_scheduler.go b/ddl/backfilling_scheduler.go index 209bc15277224..9dac7c81e31c1 100644 --- a/ddl/backfilling_scheduler.go +++ b/ddl/backfilling_scheduler.go @@ -268,7 +268,7 @@ type ingestBackfillScheduler struct { copReqSenderPool *copReqSenderPool - writerPool *workerpool.WorkerPool[idxRecResult, workerpool.None] + writerPool *workerpool.WorkerPool[IndexRecordChunk, workerpool.None] writerMaxID int poolErr chan error backendCtx ingest.BackendCtx @@ -308,9 +308,9 @@ func (b *ingestBackfillScheduler) setupWorkers() error { } b.copReqSenderPool = copReqSenderPool readerCnt, writerCnt := b.expectedWorkerSize() - writerPool := workerpool.NewWorkerPool[idxRecResult]("ingest_writer", + writerPool := workerpool.NewWorkerPool[IndexRecordChunk]("ingest_writer", poolutil.DDL, writerCnt, b.createWorker) - writerPool.Start() + writerPool.Start(b.ctx) b.writerPool = writerPool b.copReqSenderPool.chunkSender = writerPool b.copReqSenderPool.adjustSize(readerCnt) @@ -379,7 +379,7 @@ func (b *ingestBackfillScheduler) adjustWorkerSize() error { return nil } -func (b *ingestBackfillScheduler) createWorker() workerpool.Worker[idxRecResult, workerpool.None] { +func (b *ingestBackfillScheduler) createWorker() workerpool.Worker[IndexRecordChunk, workerpool.None] { reorgInfo := b.reorgInfo job := reorgInfo.Job sessCtx, err := newSessCtx(reorgInfo) @@ -428,7 +428,7 @@ func (b *ingestBackfillScheduler) createCopReqSenderPool() (*copReqSenderPool, e logutil.Logger(b.ctx).Warn("cannot init cop request sender", zap.Error(err)) return nil, err } - copCtx, err := newCopContext(b.tbl.Meta(), indexInfo, sessCtx) + copCtx, err := NewCopContext(b.tbl.Meta(), indexInfo, sessCtx) if err != nil { logutil.Logger(b.ctx).Warn("cannot init cop request sender", zap.Error(err)) return nil, err @@ -437,25 +437,29 @@ func (b *ingestBackfillScheduler) createCopReqSenderPool() (*copReqSenderPool, e } func (*ingestBackfillScheduler) expectedWorkerSize() (readerSize int, writerSize int) { + return expectedIngestWorkerCnt() +} + +func expectedIngestWorkerCnt() (readerCnt, writerCnt int) { workerCnt := int(variable.GetDDLReorgWorkerCounter()) - readerSize = mathutil.Min(workerCnt/2, maxBackfillWorkerSize) - readerSize = mathutil.Max(readerSize, 1) - writerSize = mathutil.Min(workerCnt/2+2, maxBackfillWorkerSize) - return readerSize, writerSize + readerCnt = mathutil.Min(workerCnt/2, maxBackfillWorkerSize) + readerCnt = mathutil.Max(readerCnt, 1) + writerCnt = mathutil.Min(workerCnt/2+2, maxBackfillWorkerSize) + return readerCnt, writerCnt } -func (w *addIndexIngestWorker) HandleTask(rs idxRecResult) (_ workerpool.None) { +func (w *addIndexIngestWorker) HandleTask(rs IndexRecordChunk, _ func(workerpool.None)) { defer util.Recover(metrics.LabelDDL, "ingestWorker.HandleTask", func() { - w.resultCh <- &backfillResult{taskID: rs.id, err: dbterror.ErrReorgPanic} + w.resultCh <- &backfillResult{taskID: rs.ID, err: dbterror.ErrReorgPanic} }, false) - defer w.copReqSenderPool.recycleChunk(rs.chunk) + defer w.copReqSenderPool.recycleChunk(rs.Chunk) result := &backfillResult{ - taskID: rs.id, - err: rs.err, + taskID: rs.ID, + err: rs.Err, } if result.err != nil { 
logutil.Logger(w.ctx).Error("encounter error when handle index chunk", - zap.Int("id", rs.id), zap.Error(rs.err)) + zap.Int("id", rs.ID), zap.Error(rs.Err)) w.resultCh <- result return } @@ -474,14 +478,14 @@ func (w *addIndexIngestWorker) HandleTask(rs idxRecResult) (_ workerpool.None) { return } if count == 0 { - logutil.Logger(w.ctx).Info("finish a cop-request task", zap.Int("id", rs.id)) + logutil.Logger(w.ctx).Info("finish a cop-request task", zap.Int("id", rs.ID)) return } if w.checkpointMgr != nil { cnt, nextKey := w.checkpointMgr.Status() result.totalCount = cnt result.nextKey = nextKey - result.err = w.checkpointMgr.UpdateCurrent(rs.id, count) + result.err = w.checkpointMgr.UpdateCurrent(rs.ID, count) } else { result.addedCount = count result.scanCount = count @@ -491,7 +495,6 @@ func (w *addIndexIngestWorker) HandleTask(rs idxRecResult) (_ workerpool.None) { ResultCounterForTest.Add(1) } w.resultCh <- result - return } func (*addIndexIngestWorker) Close() {} diff --git a/ddl/export_test.go b/ddl/export_test.go index 633c28c8a9de6..9bc3581d9ca86 100644 --- a/ddl/export_test.go +++ b/ddl/export_test.go @@ -27,13 +27,13 @@ import ( "github.com/pingcap/tidb/util/chunk" ) -var NewCopContext4Test = newCopContext +var NewCopContext4Test = NewCopContext type resultChanForTest struct { - ch chan idxRecResult + ch chan IndexRecordChunk } -func (r *resultChanForTest) AddTask(rs idxRecResult) { +func (r *resultChanForTest) AddTask(rs IndexRecordChunk) { r.ch <- rs } @@ -47,7 +47,7 @@ func FetchChunk4Test(copCtx *copContext, tbl table.PhysicalTable, startKey, endK physicalTable: tbl, } taskCh := make(chan *reorgBackfillTask, 5) - resultCh := make(chan idxRecResult, 5) + resultCh := make(chan IndexRecordChunk, 5) sessPool := session.NewSessionPool(nil, store) pool := newCopReqSenderPool(context.Background(), copCtx, store, taskCh, sessPool, nil) pool.chunkSender = &resultChanForTest{ch: resultCh} @@ -56,7 +56,7 @@ func FetchChunk4Test(copCtx *copContext, tbl table.PhysicalTable, startKey, endK rs := <-resultCh close(taskCh) pool.close(false) - return rs.chunk + return rs.Chunk } func ConvertRowToHandleAndIndexDatum(row chunk.Row, copCtx *copContext) (kv.Handle, []types.Datum, error) { diff --git a/ddl/index.go b/ddl/index.go index df8f77be9d5a5..c2369fb290f76 100644 --- a/ddl/index.go +++ b/ddl/index.go @@ -1653,11 +1653,11 @@ func newAddIndexIngestWorker(ctx context.Context, t table.PhysicalTable, d *ddlC } // WriteLocal will write index records to lightning engine. -func (w *addIndexIngestWorker) WriteLocal(rs *idxRecResult) (count int, nextKey kv.Key, err error) { +func (w *addIndexIngestWorker) WriteLocal(rs *IndexRecordChunk) (count int, nextKey kv.Key, err error) { oprStartTime := time.Now() copCtx := w.copReqSenderPool.copCtx vars := w.sessCtx.GetSessionVars() - cnt, lastHandle, err := writeChunkToLocal(w.writer, w.index, copCtx, vars, rs.chunk) + cnt, lastHandle, err := writeChunkToLocal(w.writer, w.index, copCtx, vars, rs.Chunk) if err != nil || cnt == 0 { return 0, nil, err } diff --git a/ddl/index_cop.go b/ddl/index_cop.go index d3fdd59fa5ad1..7f703fab8268c 100644 --- a/ddl/index_cop.go +++ b/ddl/index_cop.go @@ -64,7 +64,7 @@ func copReadChunkPoolSize() int { // chunkSender is used to receive the result of coprocessor request. 
type chunkSender interface { - AddTask(idxRecResult) + AddTask(IndexRecordChunk) } type copReqSenderPool struct { @@ -95,12 +95,12 @@ func (c *copReqSender) run() { p := c.senderPool defer p.wg.Done() defer util.Recover(metrics.LabelDDL, "copReqSender.run", func() { - p.chunkSender.AddTask(idxRecResult{err: dbterror.ErrReorgPanic}) + p.chunkSender.AddTask(IndexRecordChunk{Err: dbterror.ErrReorgPanic}) }, false) sessCtx, err := p.sessPool.Get() if err != nil { logutil.Logger(p.ctx).Error("copReqSender get session from pool failed", zap.Error(err)) - p.chunkSender.AddTask(idxRecResult{err: err}) + p.chunkSender.AddTask(IndexRecordChunk{Err: err}) return } se := sess.NewSession(sessCtx) @@ -121,7 +121,7 @@ func (c *copReqSender) run() { } err := scanRecords(p, task, se) if err != nil { - p.chunkSender.AddTask(idxRecResult{id: task.id, err: err}) + p.chunkSender.AddTask(IndexRecordChunk{ID: task.id, Err: err}) return } } @@ -156,9 +156,9 @@ func scanRecords(p *copReqSenderPool, task *reorgBackfillTask, se *sess.Session) if p.checkpointMgr != nil { p.checkpointMgr.UpdateTotal(task.id, srcChk.NumRows(), done) } - idxRs := idxRecResult{id: task.id, chunk: srcChk, done: done} + idxRs := IndexRecordChunk{ID: task.id, Chunk: srcChk, Done: done} failpoint.Inject("mockCopSenderError", func() { - idxRs.err = errors.New("mock cop error") + idxRs.Err = errors.New("mock cop error") }) p.chunkSender.AddTask(idxRs) } @@ -273,7 +273,14 @@ type copContext struct { virtualColFieldTps []*types.FieldType } -func newCopContext(tblInfo *model.TableInfo, idxInfo *model.IndexInfo, sessCtx sessionctx.Context) (*copContext, error) { +// FieldTypes is only used for test. +// TODO(tangenta): refactor the operators to avoid using this method. +func (c *copContext) FieldTypes() []*types.FieldType { + return c.fieldTps +} + +// NewCopContext creates a copContext. +func NewCopContext(tblInfo *model.TableInfo, idxInfo *model.IndexInfo, sessCtx sessionctx.Context) (*copContext, error) { var err error usedColumnIDs := make(map[int64]struct{}, len(idxInfo.Columns)) usedColumnIDs, err = fillUsedColumns(usedColumnIDs, idxInfo, tblInfo) @@ -526,10 +533,3 @@ func buildHandle(pkDts []types.Datum, tblInfo *model.TableInfo, } return kv.IntHandle(pkDts[0].GetInt64()), nil } - -type idxRecResult struct { - id int - chunk *chunk.Chunk - err error - done bool -} diff --git a/ddl/ingest/mock.go b/ddl/ingest/mock.go index 2b6e6f0d8bf28..b2e8d9ed32675 100644 --- a/ddl/ingest/mock.go +++ b/ddl/ingest/mock.go @@ -148,10 +148,23 @@ func (m *MockBackendCtx) GetCheckpointManager() *CheckpointManager { return m.checkpointMgr } +// MockWriteHook the hook for write in mock engine. +type MockWriteHook func(key, val []byte) + // MockEngineInfo is a mock engine info. type MockEngineInfo struct { sessCtx sessionctx.Context mu *sync.Mutex + + onWrite MockWriteHook +} + +// NewMockEngineInfo creates a new mock engine info. +func NewMockEngineInfo(sessCtx sessionctx.Context) *MockEngineInfo { + return &MockEngineInfo{ + sessCtx: sessCtx, + mu: &sync.Mutex{}, + } } // Flush implements Engine.Flush interface. @@ -168,16 +181,22 @@ func (*MockEngineInfo) ImportAndClean() error { func (*MockEngineInfo) Clean() { } +// SetHook set the write hook. +func (m *MockEngineInfo) SetHook(onWrite func(key, val []byte)) { + m.onWrite = onWrite +} + // CreateWriter implements Engine.CreateWriter interface. 
func (m *MockEngineInfo) CreateWriter(id int, _ bool) (Writer, error) { logutil.BgLogger().Info("mock engine info create writer", zap.Int("id", id)) - return &MockWriter{sessCtx: m.sessCtx, mu: m.mu}, nil + return &MockWriter{sessCtx: m.sessCtx, mu: m.mu, onWrite: m.onWrite}, nil } // MockWriter is a mock writer. type MockWriter struct { sessCtx sessionctx.Context mu *sync.Mutex + onWrite MockWriteHook } // WriteRow implements Writer.WriteRow interface. @@ -187,6 +206,10 @@ func (m *MockWriter) WriteRow(key, idxVal []byte, _ kv.Handle) error { zap.String("idxVal", hex.EncodeToString(idxVal))) m.mu.Lock() defer m.mu.Unlock() + if m.onWrite != nil { + m.onWrite(key, idxVal) + return nil + } txn, err := m.sessCtx.Txn(true) if err != nil { return err diff --git a/disttask/operator/compose.go b/disttask/operator/compose.go index fb1637c5154e9..cc5ba69b2062d 100644 --- a/disttask/operator/compose.go +++ b/disttask/operator/compose.go @@ -21,6 +21,8 @@ type WithSource[T any] interface { // WithSink is an interface that can be used to set the sink of an operator. type WithSink[T any] interface { + // SetSink sets the sink of the operator. + // Operator implementations should call the Finish method of the sink when they are done. SetSink(channel DataChannel[T]) } diff --git a/disttask/operator/operator.go b/disttask/operator/operator.go index c1baa8f45b1b8..df6480362ca2a 100644 --- a/disttask/operator/operator.go +++ b/disttask/operator/operator.go @@ -15,6 +15,7 @@ package operator import ( + "context" "fmt" "github.com/pingcap/tidb/resourcemanager/pool/workerpool" @@ -34,25 +35,32 @@ type Operator interface { // use the same channel, Then op2's worker will handle // the result from op1. type AsyncOperator[T, R any] struct { + ctx context.Context pool *workerpool.WorkerPool[T, R] } // NewAsyncOperatorWithTransform create an AsyncOperator with a transform function. -func NewAsyncOperatorWithTransform[T, R any](name string, workerNum int, transform func(T) R) *AsyncOperator[T, R] { +func NewAsyncOperatorWithTransform[T, R any]( + ctx context.Context, + name string, + workerNum int, + transform func(T) R, +) *AsyncOperator[T, R] { pool := workerpool.NewWorkerPool(name, util.DistTask, workerNum, newAsyncWorkerCtor(transform)) - return NewAsyncOperator(pool) + return NewAsyncOperator(ctx, pool) } // NewAsyncOperator create an AsyncOperator. -func NewAsyncOperator[T, R any](pool *workerpool.WorkerPool[T, R]) *AsyncOperator[T, R] { +func NewAsyncOperator[T, R any](ctx context.Context, pool *workerpool.WorkerPool[T, R]) *AsyncOperator[T, R] { return &AsyncOperator[T, R]{ + ctx: ctx, pool: pool, } } // Open implements the Operator's Open interface. 
func (c *AsyncOperator[T, R]) Open() error { - c.pool.Start() + c.pool.Start(c.ctx) return nil } @@ -95,8 +103,9 @@ func newAsyncWorkerCtor[T, R any](transform func(T) R) func() workerpool.Worker[ } } -func (s *asyncWorker[T, R]) HandleTask(task T) R { - return s.transform(task) +func (s *asyncWorker[T, R]) HandleTask(task T, rsFn func(R)) { + result := s.transform(task) + rsFn(result) } func (*asyncWorker[T, R]) Close() {} diff --git a/disttask/operator/wrapper.go b/disttask/operator/wrapper.go index d77c8abbc82e8..d800eb825791c 100644 --- a/disttask/operator/wrapper.go +++ b/disttask/operator/wrapper.go @@ -15,6 +15,7 @@ package operator import ( + "context" "fmt" "golang.org/x/sync/errgroup" @@ -104,7 +105,7 @@ func (s *simpleOperator[T, R]) String() string { } func newSimpleOperator[T, R any](transform func(task T) R, concurrency int) *simpleOperator[T, R] { - asyncOp := NewAsyncOperatorWithTransform("simple", concurrency, transform) + asyncOp := NewAsyncOperatorWithTransform(context.Background(), "simple", concurrency, transform) return &simpleOperator[T, R]{ AsyncOperator: asyncOp, } diff --git a/executor/executor.go b/executor/executor.go index 18988a8f1f10d..12aaf6ff36c41 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -2352,7 +2352,7 @@ func getCheckSum(ctx context.Context, se sessionctx.Context, sql string) ([]grou } // HandleTask implements the Worker interface. -func (w *checkIndexWorker) HandleTask(task checkIndexTask) (_ workerpool.None) { +func (w *checkIndexWorker) HandleTask(task checkIndexTask, _ func(workerpool.None)) { defer w.e.wg.Done() idxInfo := w.indexInfos[task.indexOffset] bucketSize := int(CheckTableFastBucketSize.Load()) @@ -2690,7 +2690,6 @@ func (w *checkIndexWorker) HandleTask(task checkIndexTask) (_ workerpool.None) { } } } - return } // Close implements the Worker interface. @@ -2701,7 +2700,7 @@ func (e *FastCheckTableExec) createWorker() workerpool.Worker[checkIndexTask, wo } // Next implements the Executor Next interface. -func (e *FastCheckTableExec) Next(context.Context, *chunk.Chunk) error { +func (e *FastCheckTableExec) Next(ctx context.Context, _ *chunk.Chunk) error { if e.done || len(e.indexInfos) == 0 { return nil } @@ -2715,7 +2714,7 @@ func (e *FastCheckTableExec) Next(context.Context, *chunk.Chunk) error { workerPool := workerpool.NewWorkerPool[checkIndexTask]("checkIndex", poolutil.CheckTable, 3, e.createWorker) - workerPool.Start() + workerPool.Start(ctx) e.wg.Add(len(e.indexInfos)) for i := range e.indexInfos { diff --git a/resourcemanager/pool/workerpool/BUILD.bazel b/resourcemanager/pool/workerpool/BUILD.bazel index 30647736b656d..619e7eef69423 100644 --- a/resourcemanager/pool/workerpool/BUILD.bazel +++ b/resourcemanager/pool/workerpool/BUILD.bazel @@ -24,7 +24,7 @@ go_test( embed = [":workerpool"], flaky = True, race = "on", - shard_count = 3, + shard_count = 4, deps = [ "//resourcemanager/util", "//testkit/testsetup", diff --git a/resourcemanager/pool/workerpool/workerpool.go b/resourcemanager/pool/workerpool/workerpool.go index ef8ed41d28690..8be568c741a93 100644 --- a/resourcemanager/pool/workerpool/workerpool.go +++ b/resourcemanager/pool/workerpool/workerpool.go @@ -15,6 +15,7 @@ package workerpool import ( + "context" "time" "github.com/pingcap/tidb/metrics" @@ -26,12 +27,16 @@ import ( // Worker is worker interface. type Worker[T, R any] interface { - HandleTask(task T) R + // HandleTask consumes a task(T) and produces a result(R). + // The result is sent to the result channel by calling `send` function. 
+ HandleTask(task T, send func(R)) Close() } // WorkerPool is a pool of workers. type WorkerPool[T, R any] struct { + ctx context.Context + cancel context.CancelFunc name string numWorkers int32 originWorkers int32 @@ -86,7 +91,7 @@ func (p *WorkerPool[T, R]) SetResultSender(sender chan R) { } // Start starts default count of workers. -func (p *WorkerPool[T, R]) Start() { +func (p *WorkerPool[T, R]) Start(ctx context.Context) { if p.taskChan == nil { p.taskChan = make(chan T) } @@ -99,6 +104,8 @@ func (p *WorkerPool[T, R]) Start() { } } + p.ctx, p.cancel = context.WithCancel(ctx) + for i := 0; i < int(p.numWorkers); i++ { p.runAWorker() } @@ -111,10 +118,17 @@ func (p *WorkerPool[T, R]) handleTaskWithRecover(w Worker[T, R], task T) { }() defer tidbutil.Recover(metrics.LabelWorkerPool, "handleTaskWithRecover", nil, false) - r := w.HandleTask(task) - if p.resChan != nil { - p.resChan <- r + sendResult := func(r R) { + if p.resChan == nil { + return + } + select { + case p.resChan <- r: + case <-p.ctx.Done(): + } } + + w.HandleTask(task, sendResult) } func (p *WorkerPool[T, R]) runAWorker() { @@ -134,6 +148,9 @@ func (p *WorkerPool[T, R]) runAWorker() { case <-p.quitChan: w.Close() return + case <-p.ctx.Done(): + w.Close() + return } } }) @@ -197,10 +214,8 @@ func (p *WorkerPool[T, R]) Name() string { // ReleaseAndWait releases the pool and wait for complete. func (p *WorkerPool[T, R]) ReleaseAndWait() { close(p.quitChan) - p.wg.Wait() - if p.resChan != nil { - close(p.resChan) - } + p.Release() + p.Wait() } // Wait waits for all workers to complete. @@ -210,8 +225,12 @@ func (p *WorkerPool[T, R]) Wait() { // Release releases the pool. func (p *WorkerPool[T, R]) Release() { + if p.cancel != nil { + p.cancel() + } if p.resChan != nil { close(p.resChan) + p.resChan = nil } } diff --git a/resourcemanager/pool/workerpool/workpool_test.go b/resourcemanager/pool/workerpool/workpool_test.go index 0de9cab63f810..560b46720e1ab 100644 --- a/resourcemanager/pool/workerpool/workpool_test.go +++ b/resourcemanager/pool/workerpool/workpool_test.go @@ -15,6 +15,7 @@ package workerpool import ( + "context" "sync" "sync/atomic" "testing" @@ -33,11 +34,10 @@ type MyWorker[T int64, R struct{}] struct { id int } -func (w *MyWorker[T, R]) HandleTask(task int64) struct{} { +func (w *MyWorker[T, R]) HandleTask(task int64, _ func(struct{})) { globalCnt.Add(task) cntWg.Done() logutil.BgLogger().Info("Worker handling task") - return struct{}{} } func (w *MyWorker[T, R]) Close() { @@ -51,13 +51,14 @@ func createMyWorker() Worker[int64, struct{}] { func TestWorkerPool(t *testing.T) { // Create a worker pool with 3 workers. pool := NewWorkerPool[int64]("test", util.UNKNOWN, 3, createMyWorker) - pool.Start() + pool.Start(context.Background()) globalCnt.Store(0) g := new(errgroup.Group) + resultCh := pool.GetResultChan() g.Go(func() error { // Consume the results. - for range pool.GetResultChan() { + for range resultCh { // Do nothing. 
} return nil @@ -107,9 +108,9 @@ func TestWorkerPool(t *testing.T) { type dummyWorker[T, R any] struct { } -func (d dummyWorker[T, R]) HandleTask(task T) R { - var zero R - return zero +func (d dummyWorker[T, R]) HandleTask(task T, send func(R)) { + var r R + send(r) } func (d dummyWorker[T, R]) Close() {} @@ -120,7 +121,7 @@ func TestWorkerPoolNoneResult(t *testing.T) { func() Worker[int64, None] { return dummyWorker[int64, None]{} }) - pool.Start() + pool.Start(context.Background()) ch := pool.GetResultChan() require.Nil(t, ch) pool.ReleaseAndWait() @@ -130,7 +131,7 @@ func TestWorkerPoolNoneResult(t *testing.T) { func() Worker[int64, int64] { return dummyWorker[int64, int64]{} }) - pool2.Start() + pool2.Start(context.Background()) require.NotNil(t, pool2.GetResultChan()) pool2.ReleaseAndWait() @@ -139,7 +140,7 @@ func TestWorkerPoolNoneResult(t *testing.T) { func() Worker[int64, struct{}] { return dummyWorker[int64, struct{}]{} }) - pool3.Start() + pool3.Start(context.Background()) require.NotNil(t, pool3.GetResultChan()) pool3.ReleaseAndWait() } @@ -156,7 +157,7 @@ func TestWorkerPoolCustomChan(t *testing.T) { resultCh := make(chan int64) pool.SetResultSender(resultCh) count := 0 - g := new(errgroup.Group) + g := errgroup.Group{} g.Go(func() error { for range resultCh { count++ @@ -164,11 +165,28 @@ func TestWorkerPoolCustomChan(t *testing.T) { return nil }) - pool.Start() + pool.Start(context.Background()) for i := 0; i < 5; i++ { taskCh <- int64(i) } - pool.ReleaseAndWait() + close(taskCh) + pool.Wait() + pool.Release() require.NoError(t, g.Wait()) require.Equal(t, 5, count) } + +func TestWorkerPoolCancelContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + pool := NewWorkerPool[int64, int64]( + "test", util.UNKNOWN, 3, + func() Worker[int64, int64] { + return dummyWorker[int64, int64]{} + }) + pool.Start(ctx) + pool.AddTask(1) + + cancel() + pool.Wait() // Should not be blocked by the result channel. + require.Equal(t, 0, int(pool.Running())) +} diff --git a/tests/realtikvtest/addindextest/BUILD.bazel b/tests/realtikvtest/addindextest/BUILD.bazel index 4a8c6d0314d15..f32bb825e9189 100644 --- a/tests/realtikvtest/addindextest/BUILD.bazel +++ b/tests/realtikvtest/addindextest/BUILD.bazel @@ -29,6 +29,7 @@ go_test( "integration_test.go", "main_test.go", "multi_schema_change_test.go", + "operator_test.go", "pitr_test.go", ], embed = [":addindextest"], @@ -39,13 +40,21 @@ go_test( "//ddl/ingest", "//ddl/testutil", "//ddl/util/callback", + "//disttask/operator", "//errno", + "//kv", "//parser/model", + "//sessionctx", "//sessionctx/variable", + "//table", + "//table/tables", "//testkit", "//tests/realtikvtest", + "//util/chunk", + "@com_github_ngaut_pools//:pools", "@com_github_pingcap_failpoint//:failpoint", "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", + "@org_golang_x_sync//errgroup", ], ) diff --git a/tests/realtikvtest/addindextest/operator_test.go b/tests/realtikvtest/addindextest/operator_test.go new file mode 100644 index 0000000000000..28fb9a692390b --- /dev/null +++ b/tests/realtikvtest/addindextest/operator_test.go @@ -0,0 +1,319 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package addindextest + +import ( + "context" + "fmt" + "testing" + + "github.com/ngaut/pools" + "github.com/pingcap/tidb/ddl" + "github.com/pingcap/tidb/ddl/ingest" + "github.com/pingcap/tidb/disttask/operator" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/table/tables" + "github.com/pingcap/tidb/testkit" + "github.com/pingcap/tidb/tests/realtikvtest" + "github.com/pingcap/tidb/util/chunk" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +func TestBackfillOperators(t *testing.T) { + store, dom := realtikvtest.CreateMockStoreAndDomainAndSetup(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("drop database if exists op;") + tk.MustExec("create database op;") + tk.MustExec("use op;") + tk.MustExec(`set global tidb_ddl_enable_fast_reorg=on;`) + + tk.MustExec("create table t(a int primary key, b int, index idx(b));") + for i := 0; i < 10; i++ { + tk.MustExec("insert into t values (?, ?)", i*10000, i) + } + regionCnt := 10 + tk.MustQuery(fmt.Sprintf("split table t between (0) and (100000) regions %d;", regionCnt)). + Check(testkit.Rows(fmt.Sprintf("%d 1", regionCnt))) + // Refresh the region cache. + tk.MustQuery("select count(*) from t;").Check(testkit.Rows("10")) + + tbl, err := dom.InfoSchema().TableByName(model.NewCIStr("op"), model.NewCIStr("t")) + require.NoError(t, err) + startKey := tbl.RecordPrefix() + endKey := tbl.RecordPrefix().PrefixNext() + + tblInfo := tbl.Meta() + idxInfo := tblInfo.FindIndexByName("idx") + copCtx, err := ddl.NewCopContext(tblInfo, idxInfo, tk.Session()) + require.NoError(t, err) + + sessPool := newSessPoolForTest(t, store) + + // Test TableScanTaskSource operator. + var opTasks []ddl.TableScanTask + { + ctx := context.Background() + pTbl := tbl.(table.PhysicalTable) + src := ddl.NewTableScanTaskSource(ctx, store, pTbl, startKey, endKey) + sink := newTestSink[ddl.TableScanTask]() + + operator.Compose[ddl.TableScanTask](src, sink) + + pipeline := operator.NewAsyncPipeline(src, sink) + err = pipeline.Execute() + require.NoError(t, err) + err = pipeline.Close() + require.NoError(t, err) + + tasks := sink.collect() + require.Len(t, tasks, 10) + require.Equal(t, 1, tasks[0].ID) + require.Equal(t, startKey, tasks[0].Start) + require.Equal(t, endKey, tasks[9].End) + opTasks = tasks + } + + // Test TableScanOperator. + var chunkResults []ddl.IndexRecordChunk + { + // Make sure the buffer is large enough since the chunks do not recycled. + srcChkPool := make(chan *chunk.Chunk, regionCnt*2) + for i := 0; i < regionCnt*2; i++ { + srcChkPool <- chunk.NewChunkWithCapacity(copCtx.FieldTypes(), 100) + } + + ctx := context.Background() + src := newTestSource(opTasks...) 
+ scanOp := ddl.NewTableScanOperator(ctx, sessPool, copCtx, srcChkPool, 3) + sink := newTestSink[ddl.IndexRecordChunk]() + + operator.Compose[ddl.TableScanTask](src, scanOp) + operator.Compose[ddl.IndexRecordChunk](scanOp, sink) + + pipeline := operator.NewAsyncPipeline(src, scanOp, sink) + err = pipeline.Execute() + require.NoError(t, err) + err = pipeline.Close() + require.NoError(t, err) + + results := sink.collect() + cnt := 0 + for _, rs := range results { + require.NoError(t, rs.Err) + chkRowCnt := rs.Chunk.NumRows() + cnt += chkRowCnt + if chkRowCnt > 0 { + chunkResults = append(chunkResults, rs) + } + } + require.Equal(t, 10, cnt) + } + + // Test IndexIngestOperator. + { + ctx := context.Background() + var keys, values [][]byte + onWrite := func(key, val []byte) { + keys = append(keys, key) + values = append(values, val) + } + + srcChkPool := make(chan *chunk.Chunk, regionCnt*2) + pTbl := tbl.(table.PhysicalTable) + index := tables.NewIndex(pTbl.GetPhysicalID(), tbl.Meta(), idxInfo) + mockEngine := ingest.NewMockEngineInfo(nil) + mockEngine.SetHook(onWrite) + + src := newTestSource(chunkResults...) + ingestOp := ddl.NewIndexIngestOperator(ctx, copCtx, sessPool, pTbl, index, mockEngine, srcChkPool, 3) + sink := newTestSink[ddl.IndexWriteResult]() + + operator.Compose[ddl.IndexRecordChunk](src, ingestOp) + operator.Compose[ddl.IndexWriteResult](ingestOp, sink) + + pipeline := operator.NewAsyncPipeline(src, ingestOp, sink) + err = pipeline.Execute() + require.NoError(t, err) + err = pipeline.Close() + require.NoError(t, err) + + results := sink.collect() + for _, rs := range results { + require.NoError(t, rs.Err) + } + require.Len(t, keys, 10) + require.Len(t, values, 10) + require.Len(t, sink.collect(), 10) + } +} + +func TestBackfillOperatorPipeline(t *testing.T) { + store, dom := realtikvtest.CreateMockStoreAndDomainAndSetup(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("drop database if exists op;") + tk.MustExec("create database op;") + tk.MustExec("use op;") + tk.MustExec(`set global tidb_ddl_enable_fast_reorg=on;`) + + tk.MustExec("create table t(a int primary key, b int, index idx(b));") + for i := 0; i < 10; i++ { + tk.MustExec("insert into t values (?, ?)", i*10000, i) + } + regionCnt := 10 + tk.MustQuery(fmt.Sprintf("split table t between (0) and (100000) regions %d;", regionCnt)). + Check(testkit.Rows(fmt.Sprintf("%d 1", regionCnt))) + // Refresh the region cache. 
+ tk.MustQuery("select count(*) from t;").Check(testkit.Rows("10")) + + tbl, err := dom.InfoSchema().TableByName(model.NewCIStr("op"), model.NewCIStr("t")) + require.NoError(t, err) + startKey := tbl.RecordPrefix() + endKey := tbl.RecordPrefix().PrefixNext() + + tblInfo := tbl.Meta() + idxInfo := tblInfo.FindIndexByName("idx") + + sessPool := newSessPoolForTest(t, store) + + ctx := context.Background() + var keys, values [][]byte + onWrite := func(key, val []byte) { + keys = append(keys, key) + values = append(values, val) + } + mockEngine := ingest.NewMockEngineInfo(nil) + mockEngine.SetHook(onWrite) + + pipeline, err := ddl.NewAddIndexIngestPipeline( + ctx, store, + sessPool, + mockEngine, + tk.Session(), + tbl.(table.PhysicalTable), + idxInfo, + startKey, + endKey, + ) + require.NoError(t, err) + err = pipeline.Execute() + require.NoError(t, err) + err = pipeline.Close() + require.NoError(t, err) + require.Len(t, keys, 10) + require.Len(t, values, 10) +} + +type sessPoolForTest struct { + pool *pools.ResourcePool +} + +func newSessPoolForTest(t *testing.T, store kv.Storage) *sessPoolForTest { + return &sessPoolForTest{ + pool: pools.NewResourcePool(func() (pools.Resource, error) { + newTk := testkit.NewTestKit(t, store) + return newTk.Session(), nil + }, 8, 8, 0), + } +} + +func (p *sessPoolForTest) Get() (sessionctx.Context, error) { + resource, err := p.pool.Get() + if err != nil { + return nil, err + } + return resource.(sessionctx.Context), nil +} + +func (p *sessPoolForTest) Put(sctx sessionctx.Context) { + p.pool.Put(sctx.(pools.Resource)) +} + +type testSink[T any] struct { + errGroup errgroup.Group + ch chan T + collected []T +} + +func newTestSink[T any]() *testSink[T] { + return &testSink[T]{ + ch: make(chan T), + } +} + +func (s *testSink[T]) Open() error { + s.errGroup.Go(func() error { + for data := range s.ch { + s.collected = append(s.collected, data) + } + return nil + }) + return nil +} + +func (s *testSink[T]) Close() error { + return s.errGroup.Wait() +} + +func (s *testSink[T]) SetSource(dataCh operator.DataChannel[T]) { + s.ch = dataCh.Channel() +} + +func (s *testSink[T]) String() string { + return "testSink" +} + +func (s *testSink[T]) collect() []T { + return s.collected +} + +type testSource[T any] struct { + errGroup errgroup.Group + ch chan T + toBeSent []T +} + +func newTestSource[T any](toBeSent ...T) *testSource[T] { + return &testSource[T]{ + ch: make(chan T), + toBeSent: toBeSent, + } +} + +func (s *testSource[T]) SetSink(sink operator.DataChannel[T]) { + s.ch = sink.Channel() +} + +func (s *testSource[T]) Open() error { + s.errGroup.Go(func() error { + for _, data := range s.toBeSent { + s.ch <- data + } + close(s.ch) + return nil + }) + return nil +} + +func (s *testSource[T]) Close() error { + return s.errGroup.Wait() +} + +func (s *testSource[T]) String() string { + return "testSource" +}
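
Note (not part of the patch): a minimal sketch of how callers are expected to drive the reworked workerpool API after this change. HandleTask now receives a send func(R) callback instead of returning R, and Start takes a context.Context so that a cancelled context unblocks workers waiting to deliver results. The constructor, util.UNKNOWN component, AddTask, GetResultChan and ReleaseAndWait calls below mirror workpool_test.go; the squareWorker type and the main wiring are illustrative assumptions, not code from this commit.

package main

import (
	"context"
	"fmt"

	"github.com/pingcap/tidb/resourcemanager/pool/workerpool"
	"github.com/pingcap/tidb/resourcemanager/util"
)

// squareWorker is a hypothetical Worker implementation using the new
// two-argument HandleTask signature: results are emitted via send.
type squareWorker struct{}

func (squareWorker) HandleTask(task int64, send func(int64)) {
	send(task * task)
}

func (squareWorker) Close() {}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	pool := workerpool.NewWorkerPool[int64, int64](
		"square", util.UNKNOWN, 3,
		func() workerpool.Worker[int64, int64] { return squareWorker{} },
	)
	pool.Start(ctx) // Start now requires a context.

	// Consume results from the pool-owned result channel.
	done := make(chan struct{})
	go func() {
		defer close(done)
		for r := range pool.GetResultChan() {
			fmt.Println("result:", r)
		}
	}()

	for i := int64(1); i <= 5; i++ {
		pool.AddTask(i)
	}
	// ReleaseAndWait closes the quit channel, cancels the internal context,
	// closes the result channel, and waits for the workers to exit.
	pool.ReleaseAndWait()
	<-done
}

Passing a callback instead of returning a single value lets a worker emit zero, one, or many results per task (as tableScanWorker does per fetched chunk), and routing the send through the pool lets a cancelled context abort a blocked send instead of leaking the goroutine.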