diff --git a/Makefile b/Makefile index 66b3ba0686917..2f2dba4b010f0 100644 --- a/Makefile +++ b/Makefile @@ -415,7 +415,7 @@ bazel_coverage_test: failpoint-enable bazel_ci_prepare bazel_build: bazel_ci_prepare mkdir -p bin - bazel $(BAZEL_GLOBAL_CONFIG) build $(BAZEL_CMD_CONFIG) \ + bazel $(BAZEL_GLOBAL_CONFIG) build $(BAZEL_CMD_CONFIG) --remote_download_minimal \ //... --//build:with_nogo_flag=true bazel $(BAZEL_GLOBAL_CONFIG) build $(BAZEL_CMD_CONFIG) \ //cmd/importer:importer //tidb-server:tidb-server //tidb-server:tidb-server-check --//build:with_nogo_flag=true @@ -442,27 +442,27 @@ bazel_golangcilinter: -- run $$($(PACKAGE_DIRECTORIES)) --config ./.golangci.yaml bazel_brietest: failpoint-enable bazel_ci_prepare - bazel $(BAZEL_GLOBAL_CONFIG) test $(BAZEL_CMD_CONFIG) --test_arg=-with-real-tikv \ + bazel $(BAZEL_GLOBAL_CONFIG) test $(BAZEL_CMD_CONFIG) --remote_download_minimal --test_arg=-with-real-tikv \ -- //tests/realtikvtest/brietest/... bazel_pessimistictest: failpoint-enable bazel_ci_prepare - bazel $(BAZEL_GLOBAL_CONFIG) test $(BAZEL_CMD_CONFIG) --test_arg=-with-real-tikv \ + bazel $(BAZEL_GLOBAL_CONFIG) test $(BAZEL_CMD_CONFIG) --remote_download_minimal --test_arg=-with-real-tikv \ -- //tests/realtikvtest/pessimistictest/... bazel_sessiontest: failpoint-enable bazel_ci_prepare - bazel $(BAZEL_GLOBAL_CONFIG) test $(BAZEL_CMD_CONFIG) --test_arg=-with-real-tikv \ + bazel $(BAZEL_GLOBAL_CONFIG) test $(BAZEL_CMD_CONFIG) --remote_download_minimal --test_arg=-with-real-tikv \ -- //tests/realtikvtest/sessiontest/... bazel_statisticstest: failpoint-enable bazel_ci_prepare - bazel $(BAZEL_GLOBAL_CONFIG) test $(BAZEL_CMD_CONFIG) --test_arg=-with-real-tikv \ + bazel $(BAZEL_GLOBAL_CONFIG) test $(BAZEL_CMD_CONFIG) --remote_download_minimal --test_arg=-with-real-tikv \ -- //tests/realtikvtest/statisticstest/... 
bazel_txntest: failpoint-enable bazel_ci_prepare - bazel $(BAZEL_GLOBAL_CONFIG) test $(BAZEL_CMD_CONFIG) --test_arg=-with-real-tikv \ + bazel $(BAZEL_GLOBAL_CONFIG) test $(BAZEL_CMD_CONFIG) --remote_download_minimal --test_arg=-with-real-tikv \ -- //tests/realtikvtest/txntest/... bazel_addindextest: failpoint-enable bazel_ci_prepare - bazel $(BAZEL_GLOBAL_CONFIG) test $(BAZEL_CMD_CONFIG) --test_arg=-with-real-tikv \ + bazel $(BAZEL_GLOBAL_CONFIG) test $(BAZEL_CMD_CONFIG) --remote_download_minimal --test_arg=-with-real-tikv \ -- //tests/realtikvtest/addindextest/... bazel_lint: bazel_prepare diff --git a/br/pkg/backup/client.go b/br/pkg/backup/client.go index 7614ca78e52c7..865e7fa2f3078 100644 --- a/br/pkg/backup/client.go +++ b/br/pkg/backup/client.go @@ -290,10 +290,12 @@ func appendRanges(tbl *model.TableInfo, tblID int64) ([]kv.KeyRange, error) { ranges = ranger.FullIntRange(false) } + retRanges := make([]kv.KeyRange, 0, 1+len(tbl.Indices)) kvRanges, err := distsql.TableHandleRangesToKVRanges(nil, []int64{tblID}, tbl.IsCommonHandle, ranges, nil) if err != nil { return nil, errors.Trace(err) } + retRanges = kvRanges.AppendSelfTo(retRanges) for _, index := range tbl.Indices { if index.State != model.StatePublic { @@ -304,9 +306,9 @@ func appendRanges(tbl *model.TableInfo, tblID int64) ([]kv.KeyRange, error) { if err != nil { return nil, errors.Trace(err) } - kvRanges = append(kvRanges, idxRanges...) + retRanges = idxRanges.AppendSelfTo(retRanges) } - return kvRanges, nil + return retRanges, nil } // BuildBackupRangeAndSchema gets KV range and schema of tables. 
diff --git a/br/pkg/checksum/executor_test.go b/br/pkg/checksum/executor_test.go index adcaed9c314f9..876103bc055a2 100644 --- a/br/pkg/checksum/executor_test.go +++ b/br/pkg/checksum/executor_test.go @@ -104,7 +104,7 @@ func TestChecksum(t *testing.T) { first = false ranges, err := backup.BuildTableRanges(tableInfo3) require.NoError(t, err) - require.Equalf(t, ranges[:1], req.KeyRanges, "%v", req.KeyRanges) + require.Equalf(t, ranges[:1], req.KeyRanges.FirstPartitionRange(), "%v", req.KeyRanges.FirstPartitionRange()) } return nil })) diff --git a/br/pkg/lightning/backend/local/duplicate.go b/br/pkg/lightning/backend/local/duplicate.go index b2858a8456f36..25bc7fabf514e 100644 --- a/br/pkg/lightning/backend/local/duplicate.go +++ b/br/pkg/lightning/backend/local/duplicate.go @@ -211,7 +211,7 @@ func physicalTableIDs(tableInfo *model.TableInfo) []int64 { } // tableHandleKeyRanges returns all key ranges associated with the tableInfo. -func tableHandleKeyRanges(tableInfo *model.TableInfo) ([]tidbkv.KeyRange, error) { +func tableHandleKeyRanges(tableInfo *model.TableInfo) (*tidbkv.KeyRanges, error) { ranges := ranger.FullIntRange(false) if tableInfo.IsCommonHandle { ranges = ranger.FullRange() @@ -221,18 +221,9 @@ func tableHandleKeyRanges(tableInfo *model.TableInfo) ([]tidbkv.KeyRange, error) } // tableIndexKeyRanges returns all key ranges associated with the tableInfo and indexInfo. -func tableIndexKeyRanges(tableInfo *model.TableInfo, indexInfo *model.IndexInfo) ([]tidbkv.KeyRange, error) { +func tableIndexKeyRanges(tableInfo *model.TableInfo, indexInfo *model.IndexInfo) (*tidbkv.KeyRanges, error) { tableIDs := physicalTableIDs(tableInfo) - //nolint: prealloc - var keyRanges []tidbkv.KeyRange - for _, tid := range tableIDs { - partitionKeysRanges, err := distsql.IndexRangesToKVRanges(nil, tid, indexInfo.ID, ranger.FullRange(), nil) - if err != nil { - return nil, errors.Trace(err) - } - keyRanges = append(keyRanges, partitionKeysRanges...) 
- } - return keyRanges, nil + return distsql.IndexRangesToKVRangesForTables(nil, tableIDs, indexInfo.ID, ranger.FullRange(), nil) } // DupKVStream is a streaming interface for collecting duplicate key-value pairs. @@ -561,14 +552,20 @@ func (m *DuplicateManager) buildDupTasks() ([]dupTask, error) { if err != nil { return nil, errors.Trace(err) } - tasks := make([]dupTask, 0, len(keyRanges)) - for _, kr := range keyRanges { - tableID := tablecodec.DecodeTableID(kr.StartKey) - tasks = append(tasks, dupTask{ - KeyRange: kr, - tableID: tableID, - }) + tasks := make([]dupTask, 0, keyRanges.TotalRangeNum()*(1+len(m.tbl.Meta().Indices))) + putToTaskFunc := func(ranges []tidbkv.KeyRange) { + if len(ranges) == 0 { + return + } + tid := tablecodec.DecodeTableID(ranges[0].StartKey) + for _, r := range ranges { + tasks = append(tasks, dupTask{ + KeyRange: r, + tableID: tid, + }) + } } + keyRanges.ForEachPartition(putToTaskFunc) for _, indexInfo := range m.tbl.Meta().Indices { if indexInfo.State != model.StatePublic { continue @@ -577,14 +574,7 @@ func (m *DuplicateManager) buildDupTasks() ([]dupTask, error) { if err != nil { return nil, errors.Trace(err) } - for _, kr := range keyRanges { - tableID := tablecodec.DecodeTableID(kr.StartKey) - tasks = append(tasks, dupTask{ - KeyRange: kr, - tableID: tableID, - indexInfo: indexInfo, - }) - } + keyRanges.ForEachPartition(putToTaskFunc) } return tasks, nil } @@ -598,15 +588,19 @@ func (m *DuplicateManager) buildIndexDupTasks() ([]dupTask, error) { if err != nil { return nil, errors.Trace(err) } - tasks := make([]dupTask, 0, len(keyRanges)) - for _, kr := range keyRanges { - tableID := tablecodec.DecodeTableID(kr.StartKey) - tasks = append(tasks, dupTask{ - KeyRange: kr, - tableID: tableID, - indexInfo: indexInfo, - }) - } + tasks := make([]dupTask, 0, keyRanges.TotalRangeNum()) + keyRanges.ForEachPartition(func(ranges []tidbkv.KeyRange) { + if len(ranges) == 0 { + return + } + tid := tablecodec.DecodeTableID(ranges[0].StartKey) + 
for _, r := range ranges { + tasks = append(tasks, dupTask{ + KeyRange: r, + tableID: tid, + }) + } + }) return tasks, nil } return nil, nil diff --git a/br/pkg/logutil/logging.go b/br/pkg/logutil/logging.go index 028cfc00e5f43..41b8e135c220f 100644 --- a/br/pkg/logutil/logging.go +++ b/br/pkg/logutil/logging.go @@ -306,3 +306,13 @@ func (rng StringifyRange) String() string { sb.WriteString(")") return sb.String() } + +// StringifyMany returns an array marshaler for a slice of stringers. +func StringifyMany[T fmt.Stringer](items []T) zapcore.ArrayMarshaler { + return zapcore.ArrayMarshalerFunc(func(ae zapcore.ArrayEncoder) error { + for _, item := range items { + ae.AppendString(item.String()) + } + return nil + }) +} diff --git a/br/pkg/streamhelper/BUILD.bazel b/br/pkg/streamhelper/BUILD.bazel index 93e13b1f8d543..83d80e52620ef 100644 --- a/br/pkg/streamhelper/BUILD.bazel +++ b/br/pkg/streamhelper/BUILD.bazel @@ -12,7 +12,6 @@ go_library( "models.go", "prefix_scanner.go", "regioniter.go", - "tsheap.go", ], importpath = "github.com/pingcap/tidb/br/pkg/streamhelper", visibility = ["//visibility:public"], @@ -21,6 +20,7 @@ go_library( "//br/pkg/logutil", "//br/pkg/redact", "//br/pkg/streamhelper/config", + "//br/pkg/streamhelper/spans", "//br/pkg/utils", "//config", "//kv", @@ -29,7 +29,6 @@ go_library( "//util/mathutil", "@com_github_gogo_protobuf//proto", "@com_github_golang_protobuf//proto", - "@com_github_google_btree//:btree", "@com_github_google_uuid//:uuid", "@com_github_pingcap_errors//:errors", "@com_github_pingcap_kvproto//pkg/brpb", @@ -44,7 +43,6 @@ go_library( "@org_golang_google_grpc//keepalive", "@org_golang_x_sync//errgroup", "@org_uber_go_zap//:zap", - "@org_uber_go_zap//zapcore", ], ) @@ -56,7 +54,6 @@ go_test( "basic_lib_for_test.go", "integration_test.go", "regioniter_test.go", - "tsheap_test.go", ], flaky = True, race = "on", @@ -68,6 +65,7 @@ go_test( "//br/pkg/redact", "//br/pkg/storage", "//br/pkg/streamhelper/config", + 
"//br/pkg/streamhelper/spans", "//br/pkg/utils", "//kv", "//tablecodec", diff --git a/br/pkg/streamhelper/advancer.go b/br/pkg/streamhelper/advancer.go index ac01c5167ffc7..60bb2928dc08a 100644 --- a/br/pkg/streamhelper/advancer.go +++ b/br/pkg/streamhelper/advancer.go @@ -3,11 +3,7 @@ package streamhelper import ( - "bytes" "context" - "math" - "reflect" - "sort" "strings" "sync" "time" @@ -17,6 +13,7 @@ import ( "github.com/pingcap/log" "github.com/pingcap/tidb/br/pkg/logutil" "github.com/pingcap/tidb/br/pkg/streamhelper/config" + "github.com/pingcap/tidb/br/pkg/streamhelper/spans" "github.com/pingcap/tidb/br/pkg/utils" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/metrics" @@ -60,81 +57,28 @@ type CheckpointAdvancer struct { // once tick begin, this should not be changed for now. cfg config.Config - // the cache of region checkpoints. - // so we can advance only ranges with huge gap. - cache CheckpointsCache - - // the internal state of advancer. - state advancerState // the cached last checkpoint. // if no progress, this cache can help us don't to send useless requests. lastCheckpoint uint64 -} - -// advancerState is the sealed type for the state of advancer. -// the advancer has two stage: full scan and update small tree. -type advancerState interface { - // Note: - // Go doesn't support sealed classes or ADTs currently. - // (it can only be used at generic constraints...) - // Leave it empty for now. - - // ~*fullScan | ~*updateSmallTree -} -// fullScan is the initial state of advancer. -// in this stage, we would "fill" the cache: -// insert ranges that union of them become the full range of task. -type fullScan struct { - fullScanTick int -} - -// updateSmallTree is the "incremental stage" of advancer. -// we have build a "filled" cache, and we can pop a subrange of it, -// try to advance the checkpoint of those ranges. 
-type updateSmallTree struct { - consistencyCheckTick int + checkpoints *spans.ValueSortedFull + checkpointsMu sync.Mutex } // NewCheckpointAdvancer creates a checkpoint advancer with the env. func NewCheckpointAdvancer(env Env) *CheckpointAdvancer { return &CheckpointAdvancer{ - env: env, - cfg: config.Default(), - cache: NewCheckpoints(), - state: &fullScan{}, + env: env, + cfg: config.Default(), } } -// disableCache removes the cache. -// note this won't lock the checkpoint advancer at `fullScan` state forever, -// you may need to change the config `AdvancingByCache`. -func (c *CheckpointAdvancer) disableCache() { - c.cache = NoOPCheckpointCache{} - c.state = &fullScan{} -} - -// enable the cache. -// also check `AdvancingByCache` in the config. -func (c *CheckpointAdvancer) enableCache() { - c.cache = NewCheckpoints() - c.state = &fullScan{} -} - // UpdateConfig updates the config for the advancer. // Note this should be called before starting the loop, because there isn't locks, // TODO: support updating config when advancer starts working. // (Maybe by applying changes at begin of ticking, and add locks.) func (c *CheckpointAdvancer) UpdateConfig(newConf config.Config) { - needRefreshCache := newConf.AdvancingByCache != c.cfg.AdvancingByCache c.cfg = newConf - if needRefreshCache { - if c.cfg.AdvancingByCache { - c.enableCache() - } else { - c.disableCache() - } - } } // UpdateConfigWith updates the config by modifying the current config. @@ -183,28 +127,24 @@ func (c *CheckpointAdvancer) recordTimeCost(message string, fields ...zap.Field) } // tryAdvance tries to advance the checkpoint ts of a set of ranges which shares the same checkpoint. 
-func (c *CheckpointAdvancer) tryAdvance(ctx context.Context, rst RangesSharesTS) (err error) { - defer c.recordTimeCost("try advance", zap.Uint64("checkpoint", rst.TS), zap.Int("len", len(rst.Ranges)))() - defer func() { - if err != nil { - log.Warn("failed to advance", logutil.ShortError(err), zap.Object("target", rst.Zap())) - c.cache.InsertRanges(rst) - } - }() +func (c *CheckpointAdvancer) tryAdvance(ctx context.Context, length int, getRange func(int) kv.KeyRange) (err error) { + defer c.recordTimeCost("try advance", zap.Int("len", length))() defer utils.PanicToErr(&err) - ranges := CollapseRanges(len(rst.Ranges), func(i int) kv.KeyRange { - return rst.Ranges[i] - }) - workers := utils.NewWorkerPool(4, "sub ranges") + ranges := spans.Collapse(length, getRange) + workers := utils.NewWorkerPool(uint(config.DefaultMaxConcurrencyAdvance)*4, "sub ranges") eg, cx := errgroup.WithContext(ctx) collector := NewClusterCollector(ctx, c.env) - collector.setOnSuccessHook(c.cache.InsertRange) + collector.setOnSuccessHook(func(u uint64, kr kv.KeyRange) { + c.checkpointsMu.Lock() + defer c.checkpointsMu.Unlock() + c.checkpoints.Merge(spans.Valued{Key: kr, Value: u}) + }) clampedRanges := utils.IntersectAll(ranges, utils.CloneSlice(c.taskRange)) for _, r := range clampedRanges { r := r workers.ApplyOnErrorGroup(eg, func() (e error) { - defer c.recordTimeCost("get regions in range", zap.Uint64("checkpoint", rst.TS))() + defer c.recordTimeCost("get regions in range")() defer utils.PanicToErr(&e) return c.GetCheckpointInRange(cx, r.StartKey, r.EndKey, collector) }) @@ -214,121 +154,44 @@ func (c *CheckpointAdvancer) tryAdvance(ctx context.Context, rst RangesSharesTS) return err } - result, err := collector.Finish(ctx) + _, err = collector.Finish(ctx) if err != nil { return err } - fr := result.FailureSubRanges - if len(fr) != 0 { - log.Debug("failure regions collected", zap.Int("size", len(fr))) - c.cache.InsertRanges(RangesSharesTS{ - TS: rst.TS, - Ranges: fr, - }) - } return 
nil } +func tsoBefore(n time.Duration) uint64 { + now := time.Now() + return oracle.ComposeTS(now.UnixMilli()-n.Milliseconds(), 0) +} + // CalculateGlobalCheckpointLight tries to advance the global checkpoint by the cache. func (c *CheckpointAdvancer) CalculateGlobalCheckpointLight(ctx context.Context) (uint64, error) { - log.Info("[log backup advancer hint] advancer with cache: current tree", zap.Stringer("ct", c.cache)) - rsts := c.cache.PopRangesWithGapGT(config.DefaultTryAdvanceThreshold) - if len(rsts) == 0 { + var targets []spans.Valued + c.checkpoints.TraverseValuesLessThan(tsoBefore(config.DefaultTryAdvanceThreshold), func(v spans.Valued) bool { + targets = append(targets, v) + return true + }) + if len(targets) == 0 { return 0, nil } - samples := rsts - if len(rsts) > 3 { - samples = rsts[:3] + samples := targets + if len(targets) > 3 { + samples = targets[:3] } for _, sample := range samples { - log.Info("[log backup advancer hint] sample range.", zap.Object("range", sample.Zap()), zap.Int("total-len", len(rsts))) + log.Info("[log backup advancer hint] sample range.", zap.Stringer("sample", sample), zap.Int("total-len", len(targets))) } - workers := utils.NewWorkerPool(uint(config.DefaultMaxConcurrencyAdvance), "regions") - eg, cx := errgroup.WithContext(ctx) - for _, rst := range rsts { - rst := rst - workers.ApplyOnErrorGroup(eg, func() (err error) { - return c.tryAdvance(cx, *rst) - }) - } - err := eg.Wait() + err := c.tryAdvance(ctx, len(targets), func(i int) kv.KeyRange { return targets[i].Key }) if err != nil { return 0, err } - ts := c.cache.CheckpointTS() + ts := c.checkpoints.MinValue() return ts, nil } -// CalculateGlobalCheckpoint calculates the global checkpoint, which won't use the cache. 
-func (c *CheckpointAdvancer) CalculateGlobalCheckpoint(ctx context.Context) (uint64, error) { - var ( - cp = uint64(math.MaxInt64) - thisRun []kv.KeyRange = c.taskRange - nextRun []kv.KeyRange - ) - defer c.recordTimeCost("record all") - for { - coll := NewClusterCollector(ctx, c.env) - coll.setOnSuccessHook(c.cache.InsertRange) - for _, u := range thisRun { - err := c.GetCheckpointInRange(ctx, u.StartKey, u.EndKey, coll) - if err != nil { - return 0, err - } - } - result, err := coll.Finish(ctx) - if err != nil { - return 0, err - } - log.Debug("full: a run finished", zap.Any("checkpoint", result)) - - nextRun = append(nextRun, result.FailureSubRanges...) - if cp > result.Checkpoint { - cp = result.Checkpoint - } - if len(nextRun) == 0 { - return cp, nil - } - thisRun = nextRun - nextRun = nil - log.Debug("backoffing with subranges", zap.Int("subranges", len(thisRun))) - time.Sleep(c.cfg.BackoffTime) - } -} - -// CollapseRanges collapse ranges overlapping or adjacent. -// Example: -// CollapseRanges({[1, 4], [2, 8], [3, 9]}) == {[1, 9]} -// CollapseRanges({[1, 3], [4, 7], [2, 3]}) == {[1, 3], [4, 7]} -func CollapseRanges(length int, getRange func(int) kv.KeyRange) []kv.KeyRange { - frs := make([]kv.KeyRange, 0, length) - for i := 0; i < length; i++ { - frs = append(frs, getRange(i)) - } - - sort.Slice(frs, func(i, j int) bool { - return bytes.Compare(frs[i].StartKey, frs[j].StartKey) < 0 - }) - - result := make([]kv.KeyRange, 0, len(frs)) - i := 0 - for i < len(frs) { - item := frs[i] - for { - i++ - if i >= len(frs) || (len(item.EndKey) != 0 && bytes.Compare(frs[i].StartKey, item.EndKey) > 0) { - break - } - if len(item.EndKey) != 0 && bytes.Compare(item.EndKey, frs[i].EndKey) < 0 || len(frs[i].EndKey) == 0 { - item.EndKey = frs[i].EndKey - } - } - result = append(result, item) - } - return result -} - func (c *CheckpointAdvancer) consumeAllTask(ctx context.Context, ch <-chan TaskEvent) error { for { select { @@ -414,18 +277,18 @@ func (c *CheckpointAdvancer) 
onTaskEvent(ctx context.Context, e TaskEvent) error case EventAdd: utils.LogBackupTaskCountInc() c.task = e.Info - c.taskRange = CollapseRanges(len(e.Ranges), func(i int) kv.KeyRange { return e.Ranges[i] }) + c.taskRange = spans.Collapse(len(e.Ranges), func(i int) kv.KeyRange { return e.Ranges[i] }) + c.checkpoints = spans.Sorted(spans.NewFullWith(e.Ranges, 0)) log.Info("added event", zap.Stringer("task", e.Info), zap.Stringer("ranges", logutil.StringifyKeys(c.taskRange))) case EventDel: utils.LogBackupTaskCountDec() c.task = nil c.taskRange = nil - c.state = &fullScan{} + c.checkpoints = nil if err := c.env.ClearV3GlobalCheckpointForTask(ctx, e.Name); err != nil { log.Warn("failed to clear global checkpoint", logutil.ShortError(err)) } metrics.LastCheckpoint.DeleteLabelValues(e.Name) - c.cache.Clear() case EventErr: return e.Err } @@ -460,58 +323,17 @@ func (c *CheckpointAdvancer) advanceCheckpointBy(ctx context.Context, getCheckpo return nil } -func (c *CheckpointAdvancer) onConsistencyCheckTick(s *updateSmallTree) error { - if s.consistencyCheckTick > 0 { - s.consistencyCheckTick-- +func (c *CheckpointAdvancer) tick(ctx context.Context) error { + c.taskMu.Lock() + defer c.taskMu.Unlock() + if c.task == nil { + log.Debug("No tasks yet, skipping advancing.") return nil } - defer c.recordTimeCost("consistency check")() - err := c.cache.ConsistencyCheck(c.taskRange) + err := c.advanceCheckpointBy(ctx, c.CalculateGlobalCheckpointLight) if err != nil { - log.Error("consistency check failed! log backup may lose data! 
rolling back to full scan for saving.", logutil.ShortError(err)) - c.state = &fullScan{} return err } - log.Debug("consistency check passed.") - s.consistencyCheckTick = config.DefaultConsistencyCheckTick - return nil -} -func (c *CheckpointAdvancer) tick(ctx context.Context) error { - c.taskMu.Lock() - defer c.taskMu.Unlock() - - switch s := c.state.(type) { - case *fullScan: - if s.fullScanTick > 0 { - s.fullScanTick-- - break - } - if c.task == nil { - log.Debug("No tasks yet, skipping advancing.") - return nil - } - defer func() { - s.fullScanTick = c.cfg.FullScanTick - }() - err := c.advanceCheckpointBy(ctx, c.CalculateGlobalCheckpoint) - if err != nil { - return err - } - - if c.cfg.AdvancingByCache { - c.state = &updateSmallTree{} - } - case *updateSmallTree: - if err := c.onConsistencyCheckTick(s); err != nil { - return err - } - err := c.advanceCheckpointBy(ctx, c.CalculateGlobalCheckpointLight) - if err != nil { - return err - } - default: - log.Error("Unknown state type, skipping tick", zap.Stringer("type", reflect.TypeOf(c.state))) - } return nil } diff --git a/br/pkg/streamhelper/basic_lib_for_test.go b/br/pkg/streamhelper/basic_lib_for_test.go index b41d5baf19528..9b73745ef65d3 100644 --- a/br/pkg/streamhelper/basic_lib_for_test.go +++ b/br/pkg/streamhelper/basic_lib_for_test.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/log" "github.com/pingcap/tidb/br/pkg/streamhelper" + "github.com/pingcap/tidb/br/pkg/streamhelper/spans" "github.com/pingcap/tidb/br/pkg/utils" "github.com/pingcap/tidb/kv" "go.uber.org/zap" @@ -82,16 +83,6 @@ type fakeCluster struct { onGetClient func(uint64) error } -func overlaps(a, b kv.KeyRange) bool { - if len(b.EndKey) == 0 { - return len(a.EndKey) == 0 || bytes.Compare(a.EndKey, b.StartKey) > 0 - } - if len(a.EndKey) == 0 { - return len(b.EndKey) == 0 || bytes.Compare(b.EndKey, a.StartKey) > 0 - } - return bytes.Compare(a.StartKey, b.EndKey) < 0 && bytes.Compare(b.StartKey, a.EndKey) 
< 0 -} - func (r *region) splitAt(newID uint64, k string) *region { newRegion := ®ion{ rng: kv.KeyRange{StartKey: []byte(k), EndKey: r.rng.EndKey}, @@ -178,7 +169,7 @@ func (f *fakeCluster) RegionScan(ctx context.Context, key []byte, endKey []byte, result := make([]streamhelper.RegionWithLeader, 0, limit) for _, region := range f.regions { - if overlaps(kv.KeyRange{StartKey: key, EndKey: endKey}, region.rng) && len(result) < limit { + if spans.Overlaps(kv.KeyRange{StartKey: key, EndKey: endKey}, region.rng) && len(result) < limit { regionInfo := streamhelper.RegionWithLeader{ Region: &metapb.Region{ Id: region.id, diff --git a/br/pkg/streamhelper/regioniter_test.go b/br/pkg/streamhelper/regioniter_test.go index 04ccc04da8a66..c8281d7a5f33b 100644 --- a/br/pkg/streamhelper/regioniter_test.go +++ b/br/pkg/streamhelper/regioniter_test.go @@ -13,6 +13,7 @@ import ( "github.com/pingcap/tidb/br/pkg/logutil" "github.com/pingcap/tidb/br/pkg/redact" "github.com/pingcap/tidb/br/pkg/streamhelper" + "github.com/pingcap/tidb/br/pkg/streamhelper/spans" "github.com/pingcap/tidb/kv" "github.com/stretchr/testify/require" ) @@ -55,7 +56,7 @@ func (c constantRegions) String() string { func (c constantRegions) RegionScan(ctx context.Context, key []byte, endKey []byte, limit int) ([]streamhelper.RegionWithLeader, error) { result := make([]streamhelper.RegionWithLeader, 0, limit) for _, region := range c { - if overlaps(kv.KeyRange{StartKey: key, EndKey: endKey}, kv.KeyRange{StartKey: region.Region.StartKey, EndKey: region.Region.EndKey}) && len(result) < limit { + if spans.Overlaps(kv.KeyRange{StartKey: key, EndKey: endKey}, kv.KeyRange{StartKey: region.Region.StartKey, EndKey: region.Region.EndKey}) && len(result) < limit { result = append(result, region) } else if bytes.Compare(region.Region.StartKey, key) > 0 { break diff --git a/br/pkg/streamhelper/spans/BUILD.bazel b/br/pkg/streamhelper/spans/BUILD.bazel new file mode 100644 index 0000000000000..899f6f6ade6b1 --- /dev/null +++ 
b/br/pkg/streamhelper/spans/BUILD.bazel @@ -0,0 +1,31 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "spans", + srcs = [ + "sorted.go", + "utils.go", + "value_sorted.go", + ], + importpath = "github.com/pingcap/tidb/br/pkg/streamhelper/spans", + visibility = ["//visibility:public"], + deps = [ + "//br/pkg/logutil", + "//br/pkg/utils", + "//kv", + "@com_github_google_btree//:btree", + ], +) + +go_test( + name = "spans_test", + srcs = [ + "sorted_test.go", + "utils_test.go", + "value_sorted_test.go", + ], + deps = [ + ":spans", + "@com_github_stretchr_testify//require", + ], +) diff --git a/br/pkg/streamhelper/spans/sorted.go b/br/pkg/streamhelper/spans/sorted.go new file mode 100644 index 0000000000000..7b9692f529e5b --- /dev/null +++ b/br/pkg/streamhelper/spans/sorted.go @@ -0,0 +1,159 @@ +package spans + +import ( + "bytes" + "fmt" + + "github.com/google/btree" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/kv" +) + +// Value is the value type of stored in the span tree. +type Value = uint64 + +// join finds the upper bound of two values. +func join(a, b Value) Value { + if a > b { + return a + } + return b +} + +// Span is the type of an adjacent sub key space. +type Span = kv.KeyRange + +// Valued is span binding to a value, which is the entry type of span tree. +type Valued struct { + Key Span + Value Value +} + +func (r Valued) String() string { + return fmt.Sprintf("(%s, %d)", logutil.StringifyRange(r.Key), r.Value) +} + +func (r Valued) Less(other btree.Item) bool { + return bytes.Compare(r.Key.StartKey, other.(Valued).Key.StartKey) < 0 +} + +// ValuedFull represents a set of valued ranges, which doesn't overlap and union of them all is the full key space. +type ValuedFull struct { + inner *btree.BTree +} + +// NewFullWith creates a set of a subset of spans. 
+func NewFullWith(initSpans []Span, init Value) *ValuedFull { + t := btree.New(16) + for _, r := range Collapse(len(initSpans), func(i int) Span { return initSpans[i] }) { + t.ReplaceOrInsert(Valued{Value: init, Key: r}) + } + return &ValuedFull{inner: t} +} + +func (f *ValuedFull) Merge(val Valued) { + overlaps := make([]Valued, 0, 16) + f.overlapped(val.Key, &overlaps) + f.mergeWithOverlap(val, overlaps, nil) +} + +func (f *ValuedFull) Traverse(m func(Valued) bool) { + f.inner.Ascend(func(item btree.Item) bool { + return m(item.(Valued)) + }) +} + +func (f *ValuedFull) mergeWithOverlap(val Valued, overlapped []Valued, newItems *[]Valued) { + // There isn't any range overlaps with the input range, perhaps the input range is empty. + // do nothing for this case. + if len(overlapped) == 0 { + return + } + + for _, r := range overlapped { + f.inner.Delete(r) + // Assert All overlapped ranges are deleted. + } + + var ( + initialized = false + collected Valued + rightTrail *Valued + flushCollected = func() { + if initialized { + f.inner.ReplaceOrInsert(collected) + if newItems != nil { + *newItems = append(*newItems, collected) + } + } + } + emitToCollected = func(rng Valued, standalone bool) { + merged := rng.Value + if !standalone { + merged = join(val.Value, rng.Value) + } + if !initialized { + collected = rng + collected.Value = merged + initialized = true + return + } + if merged == collected.Value && utils.CompareBytesExt(collected.Key.EndKey, true, rng.Key.StartKey, false) == 0 { + collected.Key.EndKey = rng.Key.EndKey + } else { + flushCollected() + collected = Valued{ + Key: rng.Key, + Value: merged, + } + } + } + ) + + leftmost := overlapped[0] + if bytes.Compare(leftmost.Key.StartKey, val.Key.StartKey) < 0 { + emitToCollected(Valued{ + Key: Span{StartKey: leftmost.Key.StartKey, EndKey: val.Key.StartKey}, + Value: leftmost.Value, + }, true) + overlapped[0].Key.StartKey = val.Key.StartKey + } + + rightmost := overlapped[len(overlapped)-1] + if 
utils.CompareBytesExt(rightmost.Key.EndKey, true, val.Key.EndKey, true) > 0 { + rightTrail = &Valued{ + Key: Span{StartKey: val.Key.EndKey, EndKey: rightmost.Key.EndKey}, + Value: rightmost.Value, + } + overlapped[len(overlapped)-1].Key.EndKey = val.Key.EndKey + } + + for _, rng := range overlapped { + emitToCollected(rng, false) + } + + if rightTrail != nil { + emitToCollected(*rightTrail, true) + } + + flushCollected() +} + +// overlapped inserts the overlapped ranges of the span into the `result` slice. +func (f *ValuedFull) overlapped(k Span, result *[]Valued) { + var first Span + f.inner.DescendLessOrEqual(Valued{Key: k}, func(item btree.Item) bool { + first = item.(Valued).Key + return false + }) + + f.inner.AscendGreaterOrEqual(Valued{Key: first}, func(item btree.Item) bool { + r := item.(Valued) + if !Overlaps(r.Key, k) { + return false + } + *result = append(*result, r) + return true + }) +} diff --git a/br/pkg/streamhelper/spans/sorted_test.go b/br/pkg/streamhelper/spans/sorted_test.go new file mode 100644 index 0000000000000..4cea720577e59 --- /dev/null +++ b/br/pkg/streamhelper/spans/sorted_test.go @@ -0,0 +1,172 @@ +package spans_test + +import ( + "fmt" + "testing" + + "github.com/pingcap/tidb/br/pkg/streamhelper/spans" + "github.com/stretchr/testify/require" +) + +func s(a, b string) spans.Span { + return spans.Span{ + StartKey: []byte(a), + EndKey: []byte(b), + } +} + +func kv(s spans.Span, v spans.Value) spans.Valued { + return spans.Valued{ + Key: s, + Value: v, + } +} + +func TestBasic(t *testing.T) { + type Case struct { + InputSequence []spans.Valued + Result []spans.Valued + } + + run := func(t *testing.T, c Case) { + full := spans.NewFullWith(spans.Full(), 0) + fmt.Println(t.Name()) + for _, i := range c.InputSequence { + full.Merge(i) + var result []spans.Valued + full.Traverse(func(v spans.Valued) bool { + result = append(result, v) + return true + }) + fmt.Printf("%s -> %s\n", i, result) + } + + var result []spans.Valued + 
full.Traverse(func(v spans.Valued) bool { + result = append(result, v) + return true + }) + + require.True(t, spans.ValuedSetEquals(result, c.Result), "%s\nvs\n%s", result, c.Result) + } + + cases := []Case{ + { + InputSequence: []spans.Valued{ + kv(s("0001", "0002"), 1), + kv(s("0002", "0003"), 2), + }, + Result: []spans.Valued{ + kv(s("", "0001"), 0), + kv(s("0001", "0002"), 1), + kv(s("0002", "0003"), 2), + kv(s("0003", ""), 0), + }, + }, + { + InputSequence: []spans.Valued{ + kv(s("0001", "0002"), 1), + kv(s("0002", "0003"), 2), + kv(s("0001", "0003"), 4), + }, + Result: []spans.Valued{ + kv(s("", "0001"), 0), + kv(s("0001", "0003"), 4), + kv(s("0003", ""), 0), + }, + }, + { + InputSequence: []spans.Valued{ + kv(s("0001", "0004"), 3), + kv(s("0004", "0008"), 5), + kv(s("0001", "0007"), 4), + kv(s("", "0002"), 2), + }, + Result: []spans.Valued{ + kv(s("", "0001"), 2), + kv(s("0001", "0004"), 4), + kv(s("0004", "0008"), 5), + kv(s("0008", ""), 0), + }, + }, + { + InputSequence: []spans.Valued{ + kv(s("0001", "0004"), 3), + kv(s("0004", "0008"), 5), + kv(s("0001", "0009"), 4), + }, + Result: []spans.Valued{ + kv(s("", "0001"), 0), + kv(s("0001", "0004"), 4), + kv(s("0004", "0008"), 5), + kv(s("0008", "0009"), 4), + kv(s("0009", ""), 0), + }, + }, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("#%d", i+1), func(t *testing.T) { run(t, c) }) + } +} + +func TestSubRange(t *testing.T) { + type Case struct { + Range []spans.Span + InputSequence []spans.Valued + Result []spans.Valued + } + + run := func(t *testing.T, c Case) { + full := spans.NewFullWith(c.Range, 0) + fmt.Println(t.Name()) + for _, i := range c.InputSequence { + full.Merge(i) + var result []spans.Valued + full.Traverse(func(v spans.Valued) bool { + result = append(result, v) + return true + }) + fmt.Printf("%s -> %s\n", i, result) + } + + var result []spans.Valued + full.Traverse(func(v spans.Valued) bool { + result = append(result, v) + return true + }) + + require.True(t, 
spans.ValuedSetEquals(result, c.Result), "%s\nvs\n%s", result, c.Result) + } + + cases := []Case{ + { + Range: []spans.Span{s("0001", "0004"), s("0008", "")}, + InputSequence: []spans.Valued{ + kv(s("0001", "0007"), 42), + kv(s("0000", "0009"), 41), + kv(s("0002", "0005"), 43), + }, + Result: []spans.Valued{ + kv(s("0001", "0002"), 42), + kv(s("0002", "0004"), 43), + kv(s("0008", "0009"), 41), + kv(s("0009", ""), 0), + }, + }, + { + Range: []spans.Span{ + s("0001", "0004"), + s("0008", "")}, + InputSequence: []spans.Valued{kv(s("", ""), 42)}, + Result: []spans.Valued{ + kv(s("0001", "0004"), 42), + kv(s("0008", ""), 42), + }, + }, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("#%d", i+1), func(t *testing.T) { run(t, c) }) + } +} diff --git a/br/pkg/streamhelper/spans/utils.go b/br/pkg/streamhelper/spans/utils.go new file mode 100644 index 0000000000000..b6369c58f5840 --- /dev/null +++ b/br/pkg/streamhelper/spans/utils.go @@ -0,0 +1,148 @@ +package spans + +import ( + "bytes" + "fmt" + "math" + "sort" + + "github.com/pingcap/tidb/br/pkg/utils" +) + +// Overlaps checks whether two spans have overlapped part. +func Overlaps(a, b Span) bool { + if len(b.EndKey) == 0 { + return len(a.EndKey) == 0 || bytes.Compare(a.EndKey, b.StartKey) > 0 + } + if len(a.EndKey) == 0 { + return len(b.EndKey) == 0 || bytes.Compare(b.EndKey, a.StartKey) > 0 + } + return bytes.Compare(a.StartKey, b.EndKey) < 0 && bytes.Compare(b.StartKey, a.EndKey) < 0 +} + +func Debug(full *ValueSortedFull) { + var result []Valued + full.Traverse(func(v Valued) bool { + result = append(result, v) + return true + }) + var idx []Valued + full.TraverseValuesLessThan(math.MaxUint64, func(v Valued) bool { + idx = append(idx, v) + return true + }) + fmt.Printf("%s\n\tidx = %s\n", result, idx) +} + +// Collapse collapse ranges overlapping or adjacent. 
+// Example: +// Collapse({[1, 4], [2, 8], [3, 9]}) == {[1, 9]} +// Collapse({[1, 3], [4, 7], [2, 3]}) == {[1, 3], [4, 7]} +func Collapse(length int, getRange func(int) Span) []Span { + frs := make([]Span, 0, length) + for i := 0; i < length; i++ { + frs = append(frs, getRange(i)) + } + + sort.Slice(frs, func(i, j int) bool { + start := bytes.Compare(frs[i].StartKey, frs[j].StartKey) + if start != 0 { + return start < 0 + } + return utils.CompareBytesExt(frs[i].EndKey, true, frs[j].EndKey, true) < 0 + }) + + result := make([]Span, 0, len(frs)) + i := 0 + for i < len(frs) { + item := frs[i] + for { + i++ + if i >= len(frs) || (len(item.EndKey) != 0 && bytes.Compare(frs[i].StartKey, item.EndKey) > 0) { + break + } + if len(item.EndKey) != 0 && bytes.Compare(item.EndKey, frs[i].EndKey) < 0 || len(frs[i].EndKey) == 0 { + item.EndKey = frs[i].EndKey + } + } + result = append(result, item) + } + return result +} + +// Full returns a full span crossing the key space. +func Full() []Span { + return []Span{{}} +} + +func (x Valued) Equals(y Valued) bool { + return x.Value == y.Value && bytes.Equal(x.Key.StartKey, y.Key.StartKey) && bytes.Equal(x.Key.EndKey, y.Key.EndKey) +} + +func ValuedSetEquals(xs, ys []Valued) bool { + if len(xs) == 0 || len(ys) == 0 { + return len(ys) == len(xs) + } + + sort.Slice(xs, func(i, j int) bool { + start := bytes.Compare(xs[i].Key.StartKey, xs[j].Key.StartKey) + if start != 0 { + return start < 0 + } + return utils.CompareBytesExt(xs[i].Key.EndKey, true, xs[j].Key.EndKey, true) < 0 + }) + sort.Slice(ys, func(i, j int) bool { + start := bytes.Compare(ys[i].Key.StartKey, ys[j].Key.StartKey) + if start != 0 { + return start < 0 + } + return utils.CompareBytesExt(ys[i].Key.EndKey, true, ys[j].Key.EndKey, true) < 0 + }) + + xi := 0 + yi := 0 + + for { + if xi >= len(xs) || yi >= len(ys) { + return (xi >= len(xs)) == (yi >= len(ys)) + } + x := xs[xi] + y := ys[yi] + + if !bytes.Equal(x.Key.StartKey, y.Key.StartKey) { + return false + } + + for { + 
if xi >= len(xs) || yi >= len(ys) { + return (xi >= len(xs)) == (yi >= len(ys)) + } + x := xs[xi] + y := ys[yi] + + if x.Value != y.Value { + return false + } + + c := utils.CompareBytesExt(x.Key.EndKey, true, y.Key.EndKey, true) + if c == 0 { + xi++ + yi++ + break + } + if c < 0 { + xi++ + // If not adjacent key, return false directly. + if xi < len(xs) && utils.CompareBytesExt(x.Key.EndKey, true, xs[xi].Key.StartKey, false) != 0 { + return false + } + } + if c > 0 { + yi++ + if yi < len(ys) && utils.CompareBytesExt(y.Key.EndKey, true, ys[yi].Key.StartKey, false) != 0 { + return false + } + } + } + } +} diff --git a/br/pkg/streamhelper/spans/utils_test.go b/br/pkg/streamhelper/spans/utils_test.go new file mode 100644 index 0000000000000..0e591d3143ec9 --- /dev/null +++ b/br/pkg/streamhelper/spans/utils_test.go @@ -0,0 +1,81 @@ +package spans_test + +import ( + "fmt" + "testing" + + "github.com/pingcap/tidb/br/pkg/streamhelper/spans" + "github.com/stretchr/testify/require" +) + +func TestValuedEquals(t *testing.T) { + s := func(start, end string, val spans.Value) spans.Valued { + return spans.Valued{ + Key: spans.Span{ + StartKey: []byte(start), + EndKey: []byte(end), + }, + Value: val, + } + } + type Case struct { + inputA []spans.Valued + inputB []spans.Valued + required bool + } + cases := []Case{ + { + inputA: []spans.Valued{s("0001", "0002", 3)}, + inputB: []spans.Valued{s("0001", "0003", 3)}, + required: false, + }, + { + inputA: []spans.Valued{s("0001", "0002", 3)}, + inputB: []spans.Valued{s("0001", "0002", 3)}, + required: true, + }, + { + inputA: []spans.Valued{s("0001", "0003", 3)}, + inputB: []spans.Valued{s("0001", "0002", 3), s("0002", "0003", 3)}, + required: true, + }, + { + inputA: []spans.Valued{s("0001", "0003", 4)}, + inputB: []spans.Valued{s("0001", "0002", 3), s("0002", "0003", 3)}, + required: false, + }, + { + inputA: []spans.Valued{s("0001", "0003", 3)}, + inputB: []spans.Valued{s("0001", "0002", 4), s("0002", "0003", 3)}, + required: 
false, + }, + { + inputA: []spans.Valued{s("0001", "0003", 3)}, + inputB: []spans.Valued{s("0001", "0002", 3), s("0002", "0004", 3)}, + required: false, + }, + { + inputA: []spans.Valued{s("", "0003", 3)}, + inputB: []spans.Valued{s("0001", "0002", 3), s("0002", "0003", 3)}, + required: false, + }, + { + inputA: []spans.Valued{s("0001", "", 1)}, + inputB: []spans.Valued{s("0001", "0003", 1), s("0004", "", 1)}, + required: false, + }, + { + inputA: []spans.Valued{s("0001", "0004", 1), s("0001", "0002", 1)}, + inputB: []spans.Valued{s("0001", "0002", 1), s("0001", "0004", 1)}, + required: true, + }, + } + run := func(t *testing.T, c Case) { + require.Equal(t, c.required, spans.ValuedSetEquals(c.inputA, c.inputB)) + require.Equal(t, c.required, spans.ValuedSetEquals(c.inputB, c.inputA)) + } + + for i, c := range cases { + t.Run(fmt.Sprintf("#%d", i+1), func(t *testing.T) { run(t, c) }) + } +} diff --git a/br/pkg/streamhelper/spans/value_sorted.go b/br/pkg/streamhelper/spans/value_sorted.go new file mode 100644 index 0000000000000..6e2cf942af578 --- /dev/null +++ b/br/pkg/streamhelper/spans/value_sorted.go @@ -0,0 +1,64 @@ +package spans + +import "github.com/google/btree" + +type sortedByValueThenStartKey Valued + +func (s sortedByValueThenStartKey) Less(o btree.Item) bool { + other := o.(sortedByValueThenStartKey) + if s.Value != other.Value { + return s.Value < other.Value + } + return Valued(s).Less(Valued(other)) +} + +type ValueSortedFull struct { + *ValuedFull + valueIdx *btree.BTree +} + +func Sorted(f *ValuedFull) *ValueSortedFull { + vf := &ValueSortedFull{ + ValuedFull: f, + valueIdx: btree.New(16), + } + f.Traverse(func(v Valued) bool { + vf.valueIdx.ReplaceOrInsert(sortedByValueThenStartKey(v)) + return true + }) + return vf +} + +func (v *ValueSortedFull) Merge(newItem Valued) { + v.MergeAll([]Valued{newItem}) +} + +func (v *ValueSortedFull) MergeAll(newItems []Valued) { + var overlapped []Valued + var inserted []Valued + + for _, item := range newItems { 
+ overlapped = overlapped[:0] + inserted = inserted[:0] + + v.overlapped(item.Key, &overlapped) + v.mergeWithOverlap(item, overlapped, &inserted) + + for _, o := range overlapped { + v.valueIdx.Delete(sortedByValueThenStartKey(o)) + } + for _, i := range inserted { + v.valueIdx.ReplaceOrInsert(sortedByValueThenStartKey(i)) + } + } +} + +func (v *ValueSortedFull) TraverseValuesLessThan(n Value, action func(Valued) bool) { + v.valueIdx.AscendLessThan(sortedByValueThenStartKey{Value: n}, func(item btree.Item) bool { + return action(Valued(item.(sortedByValueThenStartKey))) + }) +} + +func (v *ValueSortedFull) MinValue() Value { + return v.valueIdx.Min().(sortedByValueThenStartKey).Value +} diff --git a/br/pkg/streamhelper/spans/value_sorted_test.go b/br/pkg/streamhelper/spans/value_sorted_test.go new file mode 100644 index 0000000000000..36ef744fb4ef4 --- /dev/null +++ b/br/pkg/streamhelper/spans/value_sorted_test.go @@ -0,0 +1,80 @@ +package spans_test + +import ( + "fmt" + "testing" + + "github.com/pingcap/tidb/br/pkg/streamhelper/spans" + "github.com/stretchr/testify/require" +) + +func TestSortedBasic(t *testing.T) { + type Case struct { + InputSequence []spans.Valued + RetainLessThan spans.Value + Result []spans.Valued + } + + run := func(t *testing.T, c Case) { + full := spans.Sorted(spans.NewFullWith(spans.Full(), 0)) + fmt.Println(t.Name()) + for _, i := range c.InputSequence { + full.Merge(i) + spans.Debug(full) + } + + var result []spans.Valued + full.TraverseValuesLessThan(c.RetainLessThan, func(v spans.Valued) bool { + result = append(result, v) + return true + }) + + require.True(t, spans.ValuedSetEquals(result, c.Result), "%s\nvs\n%s", result, c.Result) + } + + cases := []Case{ + { + InputSequence: []spans.Valued{ + kv(s("0001", "0002"), 1), + kv(s("0002", "0003"), 2), + }, + Result: []spans.Valued{ + kv(s("", "0001"), 0), + kv(s("0001", "0002"), 1), + kv(s("0002", "0003"), 2), + kv(s("0003", ""), 0), + }, + RetainLessThan: 10, + }, + { + InputSequence: 
[]spans.Valued{ + kv(s("0001", "0002"), 1), + kv(s("0002", "0003"), 2), + kv(s("0001", "0003"), 4), + }, + RetainLessThan: 1, + Result: []spans.Valued{ + kv(s("", "0001"), 0), + kv(s("0003", ""), 0), + }, + }, + { + InputSequence: []spans.Valued{ + kv(s("0001", "0004"), 3), + kv(s("0004", "0008"), 5), + kv(s("0001", "0007"), 4), + kv(s("", "0002"), 2), + }, + RetainLessThan: 5, + Result: []spans.Valued{ + kv(s("", "0001"), 2), + kv(s("0001", "0004"), 4), + kv(s("0008", ""), 0), + }, + }, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("#%d", i+1), func(t *testing.T) { run(t, c) }) + } +} diff --git a/br/pkg/streamhelper/tsheap.go b/br/pkg/streamhelper/tsheap.go deleted file mode 100644 index 6c2fb510776e7..0000000000000 --- a/br/pkg/streamhelper/tsheap.go +++ /dev/null @@ -1,326 +0,0 @@ -// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. - -package streamhelper - -import ( - "encoding/hex" - "fmt" - "strings" - "sync" - "time" - - "github.com/google/btree" - "github.com/pingcap/errors" - berrors "github.com/pingcap/tidb/br/pkg/errors" - "github.com/pingcap/tidb/br/pkg/logutil" - "github.com/pingcap/tidb/br/pkg/redact" - "github.com/pingcap/tidb/br/pkg/utils" - "github.com/pingcap/tidb/kv" - "github.com/tikv/client-go/v2/oracle" - "go.uber.org/zap/zapcore" -) - -// CheckpointsCache is the heap-like cache for checkpoints. -// -// "Checkpoint" is the "Resolved TS" of some range. -// A resolved ts is a "watermark" for the system, which: -// - implies there won't be any transactions (in some range) commit with `commit_ts` smaller than this TS. -// - is monotonic increasing. -// A "checkpoint" is a "safe" Resolved TS, which: -// - is a TS *less than* the real resolved ts of now. -// - is based on range (it only promises there won't be new committed txns in the range). -// - the checkpoint of union of ranges is the minimal checkpoint of all ranges. 
-// As an example: -/* - +----------------------------------+ - ^-----------^ (Checkpoint = 42) - ^---------------^ (Checkpoint = 76) - ^-----------------------^ (Checkpoint = min(42, 76) = 42) -*/ -// For calculating the global checkpoint, we can make a heap-like structure: -// Checkpoint Ranges -// 42 -> {[0, 8], [16, 100]} -// 1002 -> {[8, 16]} -// 1082 -> {[100, inf]} -// For now, the checkpoint of range [8, 16] and [100, inf] won't affect the global checkpoint -// directly, so we can try to advance only the ranges of {[0, 8], [16, 100]} (which's checkpoint is steal). -// Once them get advance, the global checkpoint would be advanced then, -// and we don't need to update all ranges (because some new ranges don't need to be advanced so quickly.) -type CheckpointsCache interface { - fmt.Stringer - // InsertRange inserts a range with specified TS to the cache. - InsertRange(ts uint64, rng kv.KeyRange) - // InsertRanges inserts a set of ranges that sharing checkpoint to the cache. - InsertRanges(rst RangesSharesTS) - // CheckpointTS returns the now global (union of all ranges) checkpoint of the cache. - CheckpointTS() uint64 - // PopRangesWithGapGT pops the ranges which's checkpoint is - PopRangesWithGapGT(d time.Duration) []*RangesSharesTS - // Check whether the ranges in the cache is integrate. - ConsistencyCheck(ranges []kv.KeyRange) error - // Clear the cache. - Clear() -} - -// NoOPCheckpointCache is used when cache disabled. 
-type NoOPCheckpointCache struct{} - -func (NoOPCheckpointCache) InsertRange(ts uint64, rng kv.KeyRange) {} - -func (NoOPCheckpointCache) InsertRanges(rst RangesSharesTS) {} - -func (NoOPCheckpointCache) Clear() {} - -func (NoOPCheckpointCache) String() string { - return "NoOPCheckpointCache" -} - -func (NoOPCheckpointCache) CheckpointTS() uint64 { - panic("invalid state: NoOPCheckpointCache should never be used in advancing!") -} - -func (NoOPCheckpointCache) PopRangesWithGapGT(d time.Duration) []*RangesSharesTS { - panic("invalid state: NoOPCheckpointCache should never be used in advancing!") -} - -func (NoOPCheckpointCache) ConsistencyCheck([]kv.KeyRange) error { - return errors.Annotatef(berrors.ErrUnsupportedOperation, "invalid state: NoOPCheckpointCache should never be used in advancing!") -} - -// RangesSharesTS is a set of ranges shares the same timestamp. -type RangesSharesTS struct { - TS uint64 - Ranges []kv.KeyRange -} - -func (rst *RangesSharesTS) Zap() zapcore.ObjectMarshaler { - return zapcore.ObjectMarshalerFunc(func(oe zapcore.ObjectEncoder) error { - rngs := rst.Ranges - if len(rst.Ranges) > 3 { - rngs = rst.Ranges[:3] - } - - oe.AddUint64("checkpoint", rst.TS) - return oe.AddArray("items", zapcore.ArrayMarshalerFunc(func(ae zapcore.ArrayEncoder) error { - return ae.AppendObject(zapcore.ObjectMarshalerFunc(func(oe1 zapcore.ObjectEncoder) error { - for _, rng := range rngs { - oe1.AddString("start-key", redact.String(hex.EncodeToString(rng.StartKey))) - oe1.AddString("end-key", redact.String(hex.EncodeToString(rng.EndKey))) - } - return nil - })) - })) - }) -} - -func (rst *RangesSharesTS) String() string { - // Make a more friendly string. 
- return fmt.Sprintf("@%sR%d", oracle.GetTimeFromTS(rst.TS).Format("0405"), len(rst.Ranges)) -} - -func (rst *RangesSharesTS) Less(other btree.Item) bool { - return rst.TS < other.(*RangesSharesTS).TS -} - -// Checkpoints is a heap that collects all checkpoints of -// regions, it supports query the latest checkpoint fast. -// This structure is thread safe. -type Checkpoints struct { - tree *btree.BTree - - mu sync.Mutex -} - -func NewCheckpoints() *Checkpoints { - return &Checkpoints{ - tree: btree.New(32), - } -} - -// String formats the slowest 5 ranges sharing TS to string. -func (h *Checkpoints) String() string { - h.mu.Lock() - defer h.mu.Unlock() - - b := new(strings.Builder) - count := 0 - total := h.tree.Len() - h.tree.Ascend(func(i btree.Item) bool { - rst := i.(*RangesSharesTS) - b.WriteString(rst.String()) - b.WriteString(";") - count++ - return count < 5 - }) - if total-count > 0 { - fmt.Fprintf(b, "O%d", total-count) - } - return b.String() -} - -// InsertRanges insert a RangesSharesTS directly to the tree. -func (h *Checkpoints) InsertRanges(r RangesSharesTS) { - h.mu.Lock() - defer h.mu.Unlock() - if items := h.tree.Get(&r); items != nil { - i := items.(*RangesSharesTS) - i.Ranges = append(i.Ranges, r.Ranges...) - } else { - h.tree.ReplaceOrInsert(&r) - } -} - -// InsertRange inserts the region and its TS into the region tree. -func (h *Checkpoints) InsertRange(ts uint64, rng kv.KeyRange) { - h.mu.Lock() - defer h.mu.Unlock() - r := h.tree.Get(&RangesSharesTS{TS: ts}) - if r == nil { - r = &RangesSharesTS{TS: ts} - h.tree.ReplaceOrInsert(r) - } - rr := r.(*RangesSharesTS) - rr.Ranges = append(rr.Ranges, rng) -} - -// Clear removes all records in the checkpoint cache. -func (h *Checkpoints) Clear() { - h.mu.Lock() - defer h.mu.Unlock() - h.tree.Clear(false) -} - -// PopRangesWithGapGT pops ranges with gap greater than the specified duration. -// NOTE: maybe make something like `DrainIterator` for better composing? 
-func (h *Checkpoints) PopRangesWithGapGT(d time.Duration) []*RangesSharesTS { - h.mu.Lock() - defer h.mu.Unlock() - result := []*RangesSharesTS{} - for { - item, ok := h.tree.Min().(*RangesSharesTS) - if !ok { - return result - } - if time.Since(oracle.GetTimeFromTS(item.TS)) >= d { - result = append(result, item) - h.tree.DeleteMin() - } else { - return result - } - } -} - -// CheckpointTS returns the cached checkpoint TS by the current state of the cache. -func (h *Checkpoints) CheckpointTS() uint64 { - h.mu.Lock() - defer h.mu.Unlock() - item, ok := h.tree.Min().(*RangesSharesTS) - if !ok { - return 0 - } - return item.TS -} - -// ConsistencyCheck checks whether the tree contains the full range of key space. -func (h *Checkpoints) ConsistencyCheck(rangesIn []kv.KeyRange) error { - h.mu.Lock() - rangesReal := make([]kv.KeyRange, 0, 1024) - h.tree.Ascend(func(i btree.Item) bool { - rangesReal = append(rangesReal, i.(*RangesSharesTS).Ranges...) - return true - }) - h.mu.Unlock() - - r := CollapseRanges(len(rangesReal), func(i int) kv.KeyRange { return rangesReal[i] }) - ri := CollapseRanges(len(rangesIn), func(i int) kv.KeyRange { return rangesIn[i] }) - - return errors.Annotatef(checkIntervalIsSubset(r, ri), "ranges: (current) %s (not in) %s", logutil.StringifyKeys(r), - logutil.StringifyKeys(ri)) -} - -// A simple algorithm to detect non-overlapped ranges. -// It maintains the "current" probe, and let the ranges to check "consume" it. -// For example: -// toCheck: |_____________________| |_____________| -// . ^checking -// subsetOf: |_________| |_______| |__________| -// . ^probing -// probing is the subrange of checking, consume it and move forward the probe. -// toCheck: |_____________________| |_____________| -// . ^checking -// subsetOf: |_________| |_______| |__________| -// . ^probing -// consume it, too. -// toCheck: |_____________________| |_____________| -// . ^checking -// subsetOf: |_________| |_______| |__________| -// . 
^probing -// checking is at the left of probing and no overlaps, moving it forward. -// toCheck: |_____________________| |_____________| -// . ^checking -// subsetOf: |_________| |_______| |__________| -// . ^probing -// consume it. all subset ranges are consumed, check passed. -func checkIntervalIsSubset(toCheck []kv.KeyRange, subsetOf []kv.KeyRange) error { - i := 0 - si := 0 - - for { - // We have checked all ranges. - if si >= len(subsetOf) { - return nil - } - // There are some ranges doesn't reach the end. - if i >= len(toCheck) { - return errors.Annotatef(berrors.ErrPiTRMalformedMetadata, - "there remains a range doesn't be fully consumed: %s", - logutil.StringifyRange(subsetOf[si])) - } - - checking := toCheck[i] - probing := subsetOf[si] - // checking: |___________| - // probing: |_________| - // A rare case: the "first" range is out of bound or not fully covers the probing range. - if utils.CompareBytesExt(checking.StartKey, false, probing.StartKey, false) > 0 { - holeEnd := checking.StartKey - if utils.CompareBytesExt(holeEnd, false, probing.EndKey, true) > 0 { - holeEnd = probing.EndKey - } - return errors.Annotatef(berrors.ErrPiTRMalformedMetadata, "probably a hole in key ranges: %s", logutil.StringifyRange{ - StartKey: probing.StartKey, - EndKey: holeEnd, - }) - } - - // checking: |_____| - // probing: |_______| - // Just move forward checking. - if utils.CompareBytesExt(checking.EndKey, true, probing.StartKey, false) < 0 { - i += 1 - continue - } - - // checking: |_________| - // probing: |__________________| - // Given all of the ranges are "collapsed", the next checking range must - // not be adjacent with the current checking range. - // And hence there must be a "hole" in the probing key space. 
- if utils.CompareBytesExt(checking.EndKey, true, probing.EndKey, true) < 0 { - next := probing.EndKey - if i+1 < len(toCheck) { - next = toCheck[i+1].EndKey - } - return errors.Annotatef(berrors.ErrPiTRMalformedMetadata, "probably a hole in key ranges: %s", logutil.StringifyRange{ - StartKey: checking.EndKey, - EndKey: next, - }) - } - // checking: |________________| - // probing: |_____________| - // The current checking range fills the current probing range, - // or the current checking range is out of the current range. - // let's move the probing forward. - si += 1 - } -} diff --git a/br/pkg/streamhelper/tsheap_test.go b/br/pkg/streamhelper/tsheap_test.go deleted file mode 100644 index 173bc2e0a0334..0000000000000 --- a/br/pkg/streamhelper/tsheap_test.go +++ /dev/null @@ -1,248 +0,0 @@ -// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. -package streamhelper_test - -import ( - "fmt" - "math" - "math/rand" - "testing" - - "github.com/pingcap/tidb/br/pkg/streamhelper" - "github.com/pingcap/tidb/kv" - "github.com/stretchr/testify/require" -) - -func TestInsert(t *testing.T) { - cases := []func(func(ts uint64, a, b string)){ - func(insert func(ts uint64, a, b string)) { - insert(1, "", "01") - insert(1, "01", "02") - insert(2, "02", "022") - insert(4, "022", "") - }, - func(insert func(ts uint64, a, b string)) { - insert(1, "", "01") - insert(2, "", "01") - insert(2, "011", "02") - insert(1, "", "") - insert(65, "03", "04") - }, - } - - for _, c := range cases { - cps := streamhelper.NewCheckpoints() - expected := map[uint64]*streamhelper.RangesSharesTS{} - checkpoint := uint64(math.MaxUint64) - insert := func(ts uint64, a, b string) { - cps.InsertRange(ts, kv.KeyRange{ - StartKey: []byte(a), - EndKey: []byte(b), - }) - i, ok := expected[ts] - if !ok { - expected[ts] = &streamhelper.RangesSharesTS{TS: ts, Ranges: []kv.KeyRange{{StartKey: []byte(a), EndKey: []byte(b)}}} - } else { - i.Ranges = append(i.Ranges, kv.KeyRange{StartKey: []byte(a), EndKey: 
[]byte(b)}) - } - if ts < checkpoint { - checkpoint = ts - } - } - c(insert) - require.Equal(t, checkpoint, cps.CheckpointTS()) - rngs := cps.PopRangesWithGapGT(0) - for _, rng := range rngs { - other := expected[rng.TS] - require.Equal(t, other, rng) - } - } -} - -func TestMergeRanges(t *testing.T) { - r := func(a, b string) kv.KeyRange { - return kv.KeyRange{StartKey: []byte(a), EndKey: []byte(b)} - } - type Case struct { - expected []kv.KeyRange - parameter []kv.KeyRange - } - cases := []Case{ - { - parameter: []kv.KeyRange{r("01", "01111"), r("0111", "0112")}, - expected: []kv.KeyRange{r("01", "0112")}, - }, - { - parameter: []kv.KeyRange{r("01", "03"), r("02", "04")}, - expected: []kv.KeyRange{r("01", "04")}, - }, - { - parameter: []kv.KeyRange{r("04", "08"), r("09", "10")}, - expected: []kv.KeyRange{r("04", "08"), r("09", "10")}, - }, - { - parameter: []kv.KeyRange{r("01", "03"), r("02", "04"), r("05", "07"), r("08", "09")}, - expected: []kv.KeyRange{r("01", "04"), r("05", "07"), r("08", "09")}, - }, - { - parameter: []kv.KeyRange{r("01", "02"), r("012", "")}, - expected: []kv.KeyRange{r("01", "")}, - }, - { - parameter: []kv.KeyRange{r("", "01"), r("02", "03"), r("021", "")}, - expected: []kv.KeyRange{r("", "01"), r("02", "")}, - }, - { - parameter: []kv.KeyRange{r("", "01"), r("001", "")}, - expected: []kv.KeyRange{r("", "")}, - }, - { - parameter: []kv.KeyRange{r("", "01"), r("", ""), r("", "02")}, - expected: []kv.KeyRange{r("", "")}, - }, - { - parameter: []kv.KeyRange{r("", "01"), r("01", ""), r("", "02"), r("", "03"), r("01", "02")}, - expected: []kv.KeyRange{r("", "")}, - }, - { - parameter: []kv.KeyRange{r("", ""), r("", "01"), r("01", ""), r("01", "02")}, - expected: []kv.KeyRange{r("", "")}, - }, - } - - for i, c := range cases { - result := streamhelper.CollapseRanges(len(c.parameter), func(i int) kv.KeyRange { - return c.parameter[i] - }) - require.Equal(t, c.expected, result, "case = %d", i) - } -} - -func TestInsertRanges(t *testing.T) { - r := 
func(a, b string) kv.KeyRange { - return kv.KeyRange{StartKey: []byte(a), EndKey: []byte(b)} - } - rs := func(ts uint64, ranges ...kv.KeyRange) streamhelper.RangesSharesTS { - return streamhelper.RangesSharesTS{TS: ts, Ranges: ranges} - } - - type Case struct { - Expected []streamhelper.RangesSharesTS - Parameters []streamhelper.RangesSharesTS - } - - cases := []Case{ - { - Parameters: []streamhelper.RangesSharesTS{ - rs(1, r("0", "1"), r("1", "2")), - rs(1, r("2", "3"), r("3", "4")), - }, - Expected: []streamhelper.RangesSharesTS{ - rs(1, r("0", "1"), r("1", "2"), r("2", "3"), r("3", "4")), - }, - }, - { - Parameters: []streamhelper.RangesSharesTS{ - rs(1, r("0", "1")), - rs(2, r("2", "3")), - rs(1, r("4", "5"), r("6", "7")), - }, - Expected: []streamhelper.RangesSharesTS{ - rs(1, r("0", "1"), r("4", "5"), r("6", "7")), - rs(2, r("2", "3")), - }, - }, - } - - for _, c := range cases { - theTree := streamhelper.NewCheckpoints() - for _, p := range c.Parameters { - theTree.InsertRanges(p) - } - ranges := theTree.PopRangesWithGapGT(0) - for i, rs := range ranges { - require.ElementsMatch(t, c.Expected[i].Ranges, rs.Ranges, "case = %#v", c) - } - } -} - -func TestConsistencyCheckOverRange(t *testing.T) { - r := func(a, b string) kv.KeyRange { - return kv.KeyRange{StartKey: []byte(a), EndKey: []byte(b)} - } - type Case struct { - checking []kv.KeyRange - probing []kv.KeyRange - isSubset bool - } - - cases := []Case{ - // basic: exactly match. - { - checking: []kv.KeyRange{r("0001", "0002"), r("0002", "0003"), r("0004", "0005")}, - probing: []kv.KeyRange{r("0001", "0003"), r("0004", "0005")}, - isSubset: true, - }, - // not fully match, probing longer. - { - checking: []kv.KeyRange{r("0001", "0002"), r("0002", "0003"), r("0004", "0005")}, - probing: []kv.KeyRange{r("0000", "0003"), r("0004", "00051")}, - isSubset: false, - }, - // with infinity end keys. 
- { - checking: []kv.KeyRange{r("0001", "0002"), r("0002", "0003"), r("0004", "")}, - probing: []kv.KeyRange{r("0001", "0003"), r("0004", "")}, - isSubset: true, - }, - { - checking: []kv.KeyRange{r("0001", "0002"), r("0002", "0003"), r("0004", "")}, - probing: []kv.KeyRange{r("0001", "0003"), r("0004", "0005")}, - isSubset: true, - }, - { - checking: []kv.KeyRange{r("0001", "0002"), r("0002", "0003"), r("0004", "0005")}, - probing: []kv.KeyRange{r("0001", "0003"), r("0004", "")}, - isSubset: false, - }, - // overlapped probe. - { - checking: []kv.KeyRange{r("0001", "0002"), r("0002", "0003"), r("0004", "0007")}, - probing: []kv.KeyRange{r("0001", "0008")}, - isSubset: false, - }, - { - checking: []kv.KeyRange{r("0001", "0008")}, - probing: []kv.KeyRange{r("0001", "0002"), r("0002", "0003"), r("0004", "0007")}, - isSubset: true, - }, - { - checking: []kv.KeyRange{r("0100", "0120"), r("0130", "0141")}, - probing: []kv.KeyRange{r("0000", "0001")}, - isSubset: false, - }, - { - checking: []kv.KeyRange{r("0100", "0120")}, - probing: []kv.KeyRange{r("0090", "0110"), r("0115", "0120")}, - isSubset: false, - }, - } - - run := func(t *testing.T, c Case) { - tree := streamhelper.NewCheckpoints() - for _, r := range c.checking { - tree.InsertRange(rand.Uint64()%10, r) - } - err := tree.ConsistencyCheck(c.probing) - if c.isSubset { - require.NoError(t, err) - } else { - require.Error(t, err) - } - } - - for i, c := range cases { - t.Run(fmt.Sprintf("#%d", i), func(tc *testing.T) { - run(tc, c) - }) - } -} diff --git a/br/pkg/utils/key.go b/br/pkg/utils/key.go index 062f4b5aac52d..62d194ca57a2e 100644 --- a/br/pkg/utils/key.go +++ b/br/pkg/utils/key.go @@ -163,7 +163,7 @@ func CloneSlice[T any](s []T) []T { // toClampIn: |_____| |____| |________________| // result: |_____| |_| |______________| // we are assuming the arguments are sorted by the start key and no overlaps. -// you can call CollapseRanges to get key ranges fits this requirements. 
+// you can call spans.Collapse to get key ranges fits this requirements. // Note: this algorithm is pretty like the `checkIntervalIsSubset`, can we get them together? func IntersectAll(s1 []kv.KeyRange, s2 []kv.KeyRange) []kv.KeyRange { currentClamping := 0 diff --git a/ddl/db_partition_test.go b/ddl/db_partition_test.go index e5ad2aa2bbfec..d714ed716f9f9 100644 --- a/ddl/db_partition_test.go +++ b/ddl/db_partition_test.go @@ -1409,7 +1409,7 @@ func TestAlterTableDropPartitionByList(t *testing.T) { );`) tk.MustExec(`insert into t values (1),(3),(5),(null)`) tk.MustExec(`alter table t drop partition p1`) - tk.MustQuery("select * from t").Sort().Check(testkit.Rows("1", "5", "")) + tk.MustQuery("select * from t order by id").Check(testkit.Rows("", "1", "5")) ctx := tk.Session() is := domain.GetDomain(ctx).InfoSchema() tbl, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) diff --git a/ddl/ddl.go b/ddl/ddl.go index af8a0ca67a8d5..6e6488ca0d1c9 100644 --- a/ddl/ddl.go +++ b/ddl/ddl.go @@ -1217,8 +1217,10 @@ func (d *ddl) SwitchConcurrentDDL(toConcurrentDDL bool) error { } if err == nil { variable.EnableConcurrentDDL.Store(toConcurrentDDL) + logutil.BgLogger().Info("[ddl] SwitchConcurrentDDL", zap.Bool("toConcurrentDDL", toConcurrentDDL)) + } else { + logutil.BgLogger().Warn("[ddl] SwitchConcurrentDDL", zap.Bool("toConcurrentDDL", toConcurrentDDL), zap.Error(err)) } - logutil.BgLogger().Info("[ddl] SwitchConcurrentDDL", zap.Bool("toConcurrentDDL", toConcurrentDDL), zap.Error(err)) return err } @@ -1279,9 +1281,10 @@ func (d *ddl) SwitchMDL(enable bool) error { return err }) if err != nil { + logutil.BgLogger().Warn("[ddl] switch metadata lock feature", zap.Bool("enable", enable), zap.Error(err)) return err } - logutil.BgLogger().Info("[ddl] switch metadata lock feature", zap.Bool("enable", enable), zap.Error(err)) + logutil.BgLogger().Info("[ddl] switch metadata lock feature", zap.Bool("enable", enable)) return nil } diff --git a/ddl/reorg.go 
b/ddl/reorg.go index d8b31916fba37..a03cf417177dc 100644 --- a/ddl/reorg.go +++ b/ddl/reorg.go @@ -234,7 +234,12 @@ func (w *worker) runReorgJob(rh *reorgHandler, reorgInfo *reorgInfo, tblInfo *mo return dbterror.ErrCancelledDDLJob } rowCount, _, _ := rc.getRowCountAndKey() - logutil.BgLogger().Info("[ddl] run reorg job done", zap.Int64("handled rows", rowCount), zap.Error(err)) + if err != nil { + logutil.BgLogger().Warn("[ddl] run reorg job done", zap.Int64("handled rows", rowCount), zap.Error(err)) + } else { + logutil.BgLogger().Info("[ddl] run reorg job done", zap.Int64("handled rows", rowCount)) + } + job.SetRowCount(rowCount) // Update a job's warnings. diff --git a/distsql/request_builder.go b/distsql/request_builder.go index 4a8b3ddfeab13..a293c4d10963e 100644 --- a/distsql/request_builder.go +++ b/distsql/request_builder.go @@ -20,7 +20,6 @@ import ( "sort" "sync/atomic" - "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/tidb/ddl/placement" @@ -71,6 +70,9 @@ func (builder *RequestBuilder) Build() (*kv.Request, error) { if err != nil { builder.err = err } + if builder.Request.KeyRanges == nil { + builder.Request.KeyRanges = kv.NewNonParitionedKeyRanges(nil) + } return &builder.Request, builder.err } @@ -86,7 +88,7 @@ func (builder *RequestBuilder) SetMemTracker(tracker *memory.Tracker) *RequestBu // br refers it, so have to keep it. 
func (builder *RequestBuilder) SetTableRanges(tid int64, tableRanges []*ranger.Range, fb *statistics.QueryFeedback) *RequestBuilder { if builder.err == nil { - builder.Request.KeyRanges = TableRangesToKVRanges(tid, tableRanges, fb) + builder.Request.KeyRanges = kv.NewNonParitionedKeyRanges(TableRangesToKVRanges(tid, tableRanges, fb)) } return builder } @@ -112,7 +114,9 @@ func (builder *RequestBuilder) SetIndexRangesForTables(sc *stmtctx.StatementCont // SetHandleRanges sets "KeyRanges" for "kv.Request" by converting table handle range // "ranges" to "KeyRanges" firstly. func (builder *RequestBuilder) SetHandleRanges(sc *stmtctx.StatementContext, tid int64, isCommonHandle bool, ranges []*ranger.Range, fb *statistics.QueryFeedback) *RequestBuilder { - return builder.SetHandleRangesForTables(sc, []int64{tid}, isCommonHandle, ranges, fb) + builder = builder.SetHandleRangesForTables(sc, []int64{tid}, isCommonHandle, ranges, fb) + builder.err = builder.Request.KeyRanges.SetToNonPartitioned() + return builder } // SetHandleRangesForTables sets "KeyRanges" for "kv.Request" by converting table handle range @@ -127,14 +131,17 @@ func (builder *RequestBuilder) SetHandleRangesForTables(sc *stmtctx.StatementCon // SetTableHandles sets "KeyRanges" for "kv.Request" by converting table handles // "handles" to "KeyRanges" firstly. func (builder *RequestBuilder) SetTableHandles(tid int64, handles []kv.Handle) *RequestBuilder { - builder.Request.KeyRanges, builder.FixedRowCountHint = TableHandlesToKVRanges(tid, handles) + var keyRanges []kv.KeyRange + keyRanges, builder.FixedRowCountHint = TableHandlesToKVRanges(tid, handles) + builder.Request.KeyRanges = kv.NewNonParitionedKeyRanges(keyRanges) return builder } // SetPartitionsAndHandles sets "KeyRanges" for "kv.Request" by converting ParitionHandles to KeyRanges. // handles in slice must be kv.PartitionHandle. 
func (builder *RequestBuilder) SetPartitionsAndHandles(handles []kv.Handle) *RequestBuilder { - builder.Request.KeyRanges = PartitionHandlesToKVRanges(handles) + keyRanges := PartitionHandlesToKVRanges(handles) + builder.Request.KeyRanges = kv.NewNonParitionedKeyRanges(keyRanges) return builder } @@ -183,10 +190,22 @@ func (builder *RequestBuilder) SetChecksumRequest(checksum *tipb.ChecksumRequest // SetKeyRanges sets "KeyRanges" for "kv.Request". func (builder *RequestBuilder) SetKeyRanges(keyRanges []kv.KeyRange) *RequestBuilder { + builder.Request.KeyRanges = kv.NewNonParitionedKeyRanges(keyRanges) + return builder +} + +// SetWrappedKeyRanges sets "KeyRanges" for "kv.Request". +func (builder *RequestBuilder) SetWrappedKeyRanges(keyRanges *kv.KeyRanges) *RequestBuilder { builder.Request.KeyRanges = keyRanges return builder } +// SetPartitionKeyRanges sets the "KeyRanges" for "kv.Request" on partitioned table cases. +func (builder *RequestBuilder) SetPartitionKeyRanges(keyRanges [][]kv.KeyRange) *RequestBuilder { + builder.Request.KeyRanges = kv.NewPartitionedKeyRanges(keyRanges) + return builder +} + // SetStartTS sets "StartTS" for "kv.Request". 
func (builder *RequestBuilder) SetStartTS(startTS uint64) *RequestBuilder { builder.Request.StartTs = startTS @@ -318,13 +337,12 @@ func (builder *RequestBuilder) verifyTxnScope() error { return nil } visitPhysicalTableID := make(map[int64]struct{}) - for _, keyRange := range builder.Request.KeyRanges { - tableID := tablecodec.DecodeTableID(keyRange.StartKey) - if tableID > 0 { - visitPhysicalTableID[tableID] = struct{}{} - } else { - return errors.New("requestBuilder can't decode tableID from keyRange") - } + tids, err := tablecodec.VerifyTableIDForRanges(builder.Request.KeyRanges) + if err != nil { + return err + } + for _, tid := range tids { + visitPhysicalTableID[tid] = struct{}{} } for phyTableID := range visitPhysicalTableID { @@ -376,7 +394,7 @@ func (builder *RequestBuilder) SetClosestReplicaReadAdjuster(chkFn kv.CoprReques } // TableHandleRangesToKVRanges convert table handle ranges to "KeyRanges" for multiple tables. -func TableHandleRangesToKVRanges(sc *stmtctx.StatementContext, tid []int64, isCommonHandle bool, ranges []*ranger.Range, fb *statistics.QueryFeedback) ([]kv.KeyRange, error) { +func TableHandleRangesToKVRanges(sc *stmtctx.StatementContext, tid []int64, isCommonHandle bool, ranges []*ranger.Range, fb *statistics.QueryFeedback) (*kv.KeyRanges, error) { if !isCommonHandle { return tablesRangesToKVRanges(tid, ranges, fb), nil } @@ -387,14 +405,18 @@ func TableHandleRangesToKVRanges(sc *stmtctx.StatementContext, tid []int64, isCo // Note this function should not be exported, but currently // br refers to it, so have to keep it. func TableRangesToKVRanges(tid int64, ranges []*ranger.Range, fb *statistics.QueryFeedback) []kv.KeyRange { - return tablesRangesToKVRanges([]int64{tid}, ranges, fb) + if len(ranges) == 0 { + return []kv.KeyRange{} + } + return tablesRangesToKVRanges([]int64{tid}, ranges, fb).FirstPartitionRange() } // tablesRangesToKVRanges converts table ranges to "KeyRange". 
-func tablesRangesToKVRanges(tids []int64, ranges []*ranger.Range, fb *statistics.QueryFeedback) []kv.KeyRange { +func tablesRangesToKVRanges(tids []int64, ranges []*ranger.Range, fb *statistics.QueryFeedback) *kv.KeyRanges { if fb == nil || fb.Hist == nil { return tableRangesToKVRangesWithoutSplit(tids, ranges) } + // The following codes are deprecated since the feedback is deprecated. krs := make([]kv.KeyRange, 0, len(ranges)) feedbackRanges := make([]*ranger.Range, 0, len(ranges)) for _, ran := range ranges { @@ -420,20 +442,23 @@ func tablesRangesToKVRanges(tids []int64, ranges []*ranger.Range, fb *statistics } } fb.StoreRanges(feedbackRanges) - return krs + return kv.NewNonParitionedKeyRanges(krs) } -func tableRangesToKVRangesWithoutSplit(tids []int64, ranges []*ranger.Range) []kv.KeyRange { - krs := make([]kv.KeyRange, 0, len(ranges)*len(tids)) +func tableRangesToKVRangesWithoutSplit(tids []int64, ranges []*ranger.Range) *kv.KeyRanges { + krs := make([][]kv.KeyRange, len(tids)) + for i := range krs { + krs[i] = make([]kv.KeyRange, 0, len(ranges)) + } for _, ran := range ranges { low, high := encodeHandleKey(ran) - for _, tid := range tids { + for i, tid := range tids { startKey := tablecodec.EncodeRowKey(tid, low) endKey := tablecodec.EncodeRowKey(tid, high) - krs = append(krs, kv.KeyRange{StartKey: startKey, EndKey: endKey}) + krs[i] = append(krs[i], kv.KeyRange{StartKey: startKey, EndKey: endKey}) } } - return krs + return kv.NewPartitionedKeyRanges(krs) } func encodeHandleKey(ran *ranger.Range) ([]byte, []byte) { @@ -587,27 +612,33 @@ func PartitionHandlesToKVRanges(handles []kv.Handle) []kv.KeyRange { } // IndexRangesToKVRanges converts index ranges to "KeyRange". 
-func IndexRangesToKVRanges(sc *stmtctx.StatementContext, tid, idxID int64, ranges []*ranger.Range, fb *statistics.QueryFeedback) ([]kv.KeyRange, error) { +func IndexRangesToKVRanges(sc *stmtctx.StatementContext, tid, idxID int64, ranges []*ranger.Range, fb *statistics.QueryFeedback) (*kv.KeyRanges, error) { return IndexRangesToKVRangesWithInterruptSignal(sc, tid, idxID, ranges, fb, nil, nil) } // IndexRangesToKVRangesWithInterruptSignal converts index ranges to "KeyRange". // The process can be interrupted by set `interruptSignal` to true. -func IndexRangesToKVRangesWithInterruptSignal(sc *stmtctx.StatementContext, tid, idxID int64, ranges []*ranger.Range, fb *statistics.QueryFeedback, memTracker *memory.Tracker, interruptSignal *atomic.Value) ([]kv.KeyRange, error) { - return indexRangesToKVRangesForTablesWithInterruptSignal(sc, []int64{tid}, idxID, ranges, fb, memTracker, interruptSignal) +func IndexRangesToKVRangesWithInterruptSignal(sc *stmtctx.StatementContext, tid, idxID int64, ranges []*ranger.Range, fb *statistics.QueryFeedback, memTracker *memory.Tracker, interruptSignal *atomic.Value) (*kv.KeyRanges, error) { + keyRanges, err := indexRangesToKVRangesForTablesWithInterruptSignal(sc, []int64{tid}, idxID, ranges, fb, memTracker, interruptSignal) + if err != nil { + return nil, err + } + err = keyRanges.SetToNonPartitioned() + return keyRanges, err } // IndexRangesToKVRangesForTables converts indexes ranges to "KeyRange". -func IndexRangesToKVRangesForTables(sc *stmtctx.StatementContext, tids []int64, idxID int64, ranges []*ranger.Range, fb *statistics.QueryFeedback) ([]kv.KeyRange, error) { +func IndexRangesToKVRangesForTables(sc *stmtctx.StatementContext, tids []int64, idxID int64, ranges []*ranger.Range, fb *statistics.QueryFeedback) (*kv.KeyRanges, error) { return indexRangesToKVRangesForTablesWithInterruptSignal(sc, tids, idxID, ranges, fb, nil, nil) } // IndexRangesToKVRangesForTablesWithInterruptSignal converts indexes ranges to "KeyRange". 
// The process can be interrupted by set `interruptSignal` to true. -func indexRangesToKVRangesForTablesWithInterruptSignal(sc *stmtctx.StatementContext, tids []int64, idxID int64, ranges []*ranger.Range, fb *statistics.QueryFeedback, memTracker *memory.Tracker, interruptSignal *atomic.Value) ([]kv.KeyRange, error) { +func indexRangesToKVRangesForTablesWithInterruptSignal(sc *stmtctx.StatementContext, tids []int64, idxID int64, ranges []*ranger.Range, fb *statistics.QueryFeedback, memTracker *memory.Tracker, interruptSignal *atomic.Value) (*kv.KeyRanges, error) { if fb == nil || fb.Hist == nil { return indexRangesToKVWithoutSplit(sc, tids, idxID, ranges, memTracker, interruptSignal) } + // The following code is non maintained since the feedback deprecated. feedbackRanges := make([]*ranger.Range, 0, len(ranges)) for _, ran := range ranges { low, high, err := EncodeIndexKey(sc, ran) @@ -642,11 +673,11 @@ func indexRangesToKVRangesForTablesWithInterruptSignal(sc *stmtctx.StatementCont } } fb.StoreRanges(feedbackRanges) - return krs, nil + return kv.NewNonParitionedKeyRanges(krs), nil } // CommonHandleRangesToKVRanges converts common handle ranges to "KeyRange". 
-func CommonHandleRangesToKVRanges(sc *stmtctx.StatementContext, tids []int64, ranges []*ranger.Range) ([]kv.KeyRange, error) { +func CommonHandleRangesToKVRanges(sc *stmtctx.StatementContext, tids []int64, ranges []*ranger.Range) (*kv.KeyRanges, error) { rans := make([]*ranger.Range, 0, len(ranges)) for _, ran := range ranges { low, high, err := EncodeIndexKey(sc, ran) @@ -656,20 +687,23 @@ func CommonHandleRangesToKVRanges(sc *stmtctx.StatementContext, tids []int64, ra rans = append(rans, &ranger.Range{LowVal: []types.Datum{types.NewBytesDatum(low)}, HighVal: []types.Datum{types.NewBytesDatum(high)}, LowExclude: false, HighExclude: true, Collators: collate.GetBinaryCollatorSlice(1)}) } - krs := make([]kv.KeyRange, 0, len(rans)) + krs := make([][]kv.KeyRange, len(tids)) + for i := range krs { + krs[i] = make([]kv.KeyRange, 0, len(ranges)) + } for _, ran := range rans { low, high := ran.LowVal[0].GetBytes(), ran.HighVal[0].GetBytes() if ran.LowExclude { low = kv.Key(low).PrefixNext() } ran.LowVal[0].SetBytes(low) - for _, tid := range tids { + for i, tid := range tids { startKey := tablecodec.EncodeRowKey(tid, low) endKey := tablecodec.EncodeRowKey(tid, high) - krs = append(krs, kv.KeyRange{StartKey: startKey, EndKey: endKey}) + krs[i] = append(krs[i], kv.KeyRange{StartKey: startKey, EndKey: endKey}) } } - return krs, nil + return kv.NewPartitionedKeyRanges(krs), nil } // VerifyTxnScope verify whether the txnScope and visited physical table break the leader rule's dcLocation. 
@@ -691,8 +725,12 @@ func VerifyTxnScope(txnScope string, physicalTableID int64, is infoschema.InfoSc return true } -func indexRangesToKVWithoutSplit(sc *stmtctx.StatementContext, tids []int64, idxID int64, ranges []*ranger.Range, memTracker *memory.Tracker, interruptSignal *atomic.Value) ([]kv.KeyRange, error) { - krs := make([]kv.KeyRange, 0, len(ranges)) +func indexRangesToKVWithoutSplit(sc *stmtctx.StatementContext, tids []int64, idxID int64, ranges []*ranger.Range, memTracker *memory.Tracker, interruptSignal *atomic.Value) (*kv.KeyRanges, error) { + krs := make([][]kv.KeyRange, len(tids)) + for i := range krs { + krs[i] = make([]kv.KeyRange, 0, len(ranges)) + } + const checkSignalStep = 8 var estimatedMemUsage int64 // encodeIndexKey and EncodeIndexSeekKey is time-consuming, thus we need to @@ -705,13 +743,13 @@ func indexRangesToKVWithoutSplit(sc *stmtctx.StatementContext, tids []int64, idx if i == 0 { estimatedMemUsage += int64(cap(low) + cap(high)) } - for _, tid := range tids { + for j, tid := range tids { startKey := tablecodec.EncodeIndexSeekKey(tid, idxID, low) endKey := tablecodec.EncodeIndexSeekKey(tid, idxID, high) if i == 0 { estimatedMemUsage += int64(cap(startKey)) + int64(cap(endKey)) } - krs = append(krs, kv.KeyRange{StartKey: startKey, EndKey: endKey}) + krs[j] = append(krs[j], kv.KeyRange{StartKey: startKey, EndKey: endKey}) } if i%checkSignalStep == 0 { if i == 0 && memTracker != nil { @@ -719,11 +757,11 @@ func indexRangesToKVWithoutSplit(sc *stmtctx.StatementContext, tids []int64, idx memTracker.Consume(estimatedMemUsage) } if interruptSignal != nil && interruptSignal.Load().(bool) { - return nil, nil + return kv.NewPartitionedKeyRanges(nil), nil } } } - return krs, nil + return kv.NewPartitionedKeyRanges(krs), nil } // EncodeIndexKey gets encoded keys containing low and high diff --git a/distsql/request_builder_test.go b/distsql/request_builder_test.go index 2ffde4a512c0d..fa55229e36fa5 100644 --- a/distsql/request_builder_test.go +++ 
b/distsql/request_builder_test.go @@ -192,8 +192,8 @@ func TestIndexRangesToKVRanges(t *testing.T) { actual, err := IndexRangesToKVRanges(new(stmtctx.StatementContext), 12, 15, ranges, nil) require.NoError(t, err) - for i := range actual { - require.Equal(t, expect[i], actual[i]) + for i := range actual.FirstPartitionRange() { + require.Equal(t, expect[i], actual.FirstPartitionRange()[i]) } } @@ -242,7 +242,7 @@ func TestRequestBuilder1(t *testing.T) { Tp: 103, StartTs: 0x0, Data: []uint8{0x18, 0x0, 0x20, 0x0, 0x40, 0x0, 0x5a, 0x0}, - KeyRanges: []kv.KeyRange{ + KeyRanges: kv.NewNonParitionedKeyRanges([]kv.KeyRange{ { StartKey: kv.Key{0x74, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x5f, 0x72, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1}, EndKey: kv.Key{0x74, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x5f, 0x72, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3}, @@ -263,7 +263,7 @@ func TestRequestBuilder1(t *testing.T) { StartKey: kv.Key{0x74, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x5f, 0x72, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x23}, EndKey: kv.Key{0x74, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x5f, 0x72, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x23}, }, - }, + }), Cacheable: true, KeepOrder: false, Desc: false, @@ -325,7 +325,7 @@ func TestRequestBuilder2(t *testing.T) { Tp: 103, StartTs: 0x0, Data: []uint8{0x18, 0x0, 0x20, 0x0, 0x40, 0x0, 0x5a, 0x0}, - KeyRanges: []kv.KeyRange{ + KeyRanges: kv.NewNonParitionedKeyRanges([]kv.KeyRange{ { StartKey: kv.Key{0x74, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x5f, 0x69, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xf, 0x3, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1}, EndKey: kv.Key{0x74, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x5f, 0x69, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xf, 0x3, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3}, @@ -346,7 +346,7 @@ func TestRequestBuilder2(t *testing.T) { StartKey: kv.Key{0x74, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x5f, 0x69, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xf, 0x3, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 
0x23}, EndKey: kv.Key{0x74, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x5f, 0x69, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xf, 0x3, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x23}, }, - }, + }), Cacheable: true, KeepOrder: false, Desc: false, @@ -378,7 +378,7 @@ func TestRequestBuilder3(t *testing.T) { Tp: 103, StartTs: 0x0, Data: []uint8{0x18, 0x0, 0x20, 0x0, 0x40, 0x0, 0x5a, 0x0}, - KeyRanges: []kv.KeyRange{ + KeyRanges: kv.NewNonParitionedKeyRanges([]kv.KeyRange{ { StartKey: kv.Key{0x74, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xf, 0x5f, 0x72, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, EndKey: kv.Key{0x74, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xf, 0x5f, 0x72, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1}, @@ -395,7 +395,7 @@ func TestRequestBuilder3(t *testing.T) { StartKey: kv.Key{0x74, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xf, 0x5f, 0x72, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x64}, EndKey: kv.Key{0x74, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xf, 0x5f, 0x72, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x65}, }, - }, + }), Cacheable: true, KeepOrder: false, Desc: false, @@ -444,7 +444,7 @@ func TestRequestBuilder4(t *testing.T) { Tp: 103, StartTs: 0x0, Data: []uint8{0x18, 0x0, 0x20, 0x0, 0x40, 0x0, 0x5a, 0x0}, - KeyRanges: keyRanges, + KeyRanges: kv.NewNonParitionedKeyRanges(keyRanges), Cacheable: true, KeepOrder: false, Desc: false, @@ -491,7 +491,7 @@ func TestRequestBuilder5(t *testing.T) { Tp: 104, StartTs: 0x0, Data: []uint8{0x8, 0x0, 0x18, 0x0, 0x20, 0x0}, - KeyRanges: keyRanges, + KeyRanges: kv.NewNonParitionedKeyRanges(keyRanges), KeepOrder: true, Desc: false, Concurrency: 15, @@ -520,7 +520,7 @@ func TestRequestBuilder6(t *testing.T) { Tp: 105, StartTs: 0x0, Data: []uint8{0x10, 0x0, 0x18, 0x0}, - KeyRanges: keyRanges, + KeyRanges: kv.NewNonParitionedKeyRanges(keyRanges), KeepOrder: false, Desc: false, Concurrency: concurrency, @@ -557,6 +557,7 @@ func TestRequestBuilder7(t *testing.T) { Tp: 0, StartTs: 0x0, KeepOrder: false, + KeyRanges: kv.NewNonParitionedKeyRanges(nil), Desc: false, 
Concurrency: concurrency, IsolationLevel: 0, @@ -583,6 +584,7 @@ func TestRequestBuilder8(t *testing.T) { Tp: 0, StartTs: 0x0, Data: []uint8(nil), + KeyRanges: kv.NewNonParitionedKeyRanges(nil), Concurrency: variable.DefDistSQLScanConcurrency, IsolationLevel: 0, Priority: 0, @@ -635,8 +637,8 @@ func TestIndexRangesToKVRangesWithFbs(t *testing.T) { EndKey: kv.Key{0x74, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x5f, 0x69, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x5}, }, } - for i := 0; i < len(actual); i++ { - require.Equal(t, expect[i], actual[i]) + for i := 0; i < len(actual.FirstPartitionRange()); i++ { + require.Equal(t, expect[i], actual.FirstPartitionRange()[i]) } } diff --git a/executor/admin.go b/executor/admin.go index 1a0f5579281cc..a0484ce957b30 100644 --- a/executor/admin.go +++ b/executor/admin.go @@ -265,10 +265,11 @@ func (e *RecoverIndexExec) buildTableScan(ctx context.Context, txn kv.Transactio return nil, err } var builder distsql.RequestBuilder - builder.KeyRanges, err = buildRecoverIndexKeyRanges(e.ctx.GetSessionVars().StmtCtx, e.physicalID, startHandle) + keyRanges, err := buildRecoverIndexKeyRanges(e.ctx.GetSessionVars().StmtCtx, e.physicalID, startHandle) if err != nil { return nil, err } + builder.KeyRanges = kv.NewNonParitionedKeyRanges(keyRanges) kvReq, err := builder. SetDAGRequest(dagPB). SetStartTS(txn.StartTS()). @@ -737,7 +738,16 @@ func (e *CleanupIndexExec) buildIndexScan(ctx context.Context, txn kv.Transactio sc := e.ctx.GetSessionVars().StmtCtx var builder distsql.RequestBuilder ranges := ranger.FullRange() - kvReq, err := builder.SetIndexRanges(sc, e.physicalID, e.index.Meta().ID, ranges). 
+ keyRanges, err := distsql.IndexRangesToKVRanges(sc, e.physicalID, e.index.Meta().ID, ranges, nil) + if err != nil { + return nil, err + } + err = keyRanges.SetToNonPartitioned() + if err != nil { + return nil, err + } + keyRanges.FirstPartitionRange()[0].StartKey = kv.Key(e.lastIdxKey).PrefixNext() + kvReq, err := builder.SetWrappedKeyRanges(keyRanges). SetDAGRequest(dagPB). SetStartTS(txn.StartTS()). SetKeepOrder(true). @@ -748,7 +758,6 @@ func (e *CleanupIndexExec) buildIndexScan(ctx context.Context, txn kv.Transactio return nil, err } - kvReq.KeyRanges[0].StartKey = kv.Key(e.lastIdxKey).PrefixNext() kvReq.Concurrency = 1 result, err := distsql.Select(ctx, e.ctx, kvReq, e.getIdxColTypes(), statistics.NewQueryFeedback(0, nil, 0, false)) if err != nil { diff --git a/executor/builder.go b/executor/builder.go index 3d015849aa5ef..01bf8496fe77b 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -4262,32 +4262,37 @@ type kvRangeBuilderFromRangeAndPartition struct { } func (h kvRangeBuilderFromRangeAndPartition) buildKeyRangeSeparately(ranges []*ranger.Range) ([]int64, [][]kv.KeyRange, error) { - ret := make([][]kv.KeyRange, 0, len(h.partitions)) + ret := make([][]kv.KeyRange, len(h.partitions)) pids := make([]int64, 0, len(h.partitions)) - for _, p := range h.partitions { + for i, p := range h.partitions { pid := p.GetPhysicalID() + pids = append(pids, pid) meta := p.Meta() + if len(ranges) == 0 { + continue + } kvRange, err := distsql.TableHandleRangesToKVRanges(h.sctx.GetSessionVars().StmtCtx, []int64{pid}, meta != nil && meta.IsCommonHandle, ranges, nil) if err != nil { return nil, nil, err } - pids = append(pids, pid) - ret = append(ret, kvRange) + ret[i] = kvRange.AppendSelfTo(ret[i]) } return pids, ret, nil } -func (h kvRangeBuilderFromRangeAndPartition) buildKeyRange(ranges []*ranger.Range) ([]kv.KeyRange, error) { - //nolint: prealloc - var ret []kv.KeyRange - for _, p := range h.partitions { +func (h kvRangeBuilderFromRangeAndPartition) 
buildKeyRange(ranges []*ranger.Range) ([][]kv.KeyRange, error) { + ret := make([][]kv.KeyRange, len(h.partitions)) + if len(ranges) == 0 { + return ret, nil + } + for i, p := range h.partitions { pid := p.GetPhysicalID() meta := p.Meta() kvRange, err := distsql.TableHandleRangesToKVRanges(h.sctx.GetSessionVars().StmtCtx, []int64{pid}, meta != nil && meta.IsCommonHandle, ranges, nil) if err != nil { return nil, err } - ret = append(ret, kvRange...) + ret[i] = kvRange.AppendSelfTo(ret[i]) } return ret, nil } @@ -4334,7 +4339,7 @@ func (builder *dataReaderBuilder) buildTableReaderBase(ctx context.Context, e *T if err != nil { return nil, err } - e.kvRanges = append(e.kvRanges, kvReq.KeyRanges...) + e.kvRanges = kvReq.KeyRanges.AppendSelfTo(e.kvRanges) e.resultHandler = &tableResultHandler{} result, err := builder.SelectResult(ctx, builder.ctx, kvReq, retTypes(e), e.feedback, getPhysicalPlanIDs(e.plans), e.id) if err != nil { @@ -4357,6 +4362,8 @@ func (builder *dataReaderBuilder) buildTableReaderFromHandles(ctx context.Contex } else { b.SetTableHandles(getPhysicalTableID(e.table), handles) } + } else { + b.SetKeyRanges(nil) } return builder.buildTableReaderBase(ctx, e, b) } @@ -4545,6 +4552,9 @@ func buildRangesForIndexJoin(ctx sessionctx.Context, lookUpContents []*indexJoin func buildKvRangesForIndexJoin(ctx sessionctx.Context, tableID, indexID int64, lookUpContents []*indexJoinLookUpContent, ranges []*ranger.Range, keyOff2IdxOff []int, cwc *plannercore.ColWithCmpFuncManager, memTracker *memory.Tracker, interruptSignal *atomic.Value) (_ []kv.KeyRange, err error) { kvRanges := make([]kv.KeyRange, 0, len(ranges)*len(lookUpContents)) + if len(ranges) == 0 { + return []kv.KeyRange{}, nil + } lastPos := len(ranges[0].LowVal) - 1 sc := ctx.GetSessionVars().StmtCtx tmpDatumRanges := make([]*ranger.Range, 0, len(lookUpContents)) @@ -4557,7 +4567,7 @@ func buildKvRangesForIndexJoin(ctx sessionctx.Context, tableID, indexID int64, l } if cwc == nil { // Index id is -1 means 
it's a common handle. - var tmpKvRanges []kv.KeyRange + var tmpKvRanges *kv.KeyRanges var err error if indexID == -1 { tmpKvRanges, err = distsql.CommonHandleRangesToKVRanges(sc, []int64{tableID}, ranges) @@ -4567,7 +4577,7 @@ func buildKvRangesForIndexJoin(ctx sessionctx.Context, tableID, indexID int64, l if err != nil { return nil, err } - kvRanges = append(kvRanges, tmpKvRanges...) + kvRanges = tmpKvRanges.AppendSelfTo(kvRanges) continue } nextColRanges, err := cwc.BuildRangesByRow(ctx, content.row) @@ -4604,9 +4614,11 @@ func buildKvRangesForIndexJoin(ctx sessionctx.Context, tableID, indexID int64, l } // Index id is -1 means it's a common handle. if indexID == -1 { - return distsql.CommonHandleRangesToKVRanges(ctx.GetSessionVars().StmtCtx, []int64{tableID}, tmpDatumRanges) + tmpKeyRanges, err := distsql.CommonHandleRangesToKVRanges(ctx.GetSessionVars().StmtCtx, []int64{tableID}, tmpDatumRanges) + return tmpKeyRanges.FirstPartitionRange(), err } - return distsql.IndexRangesToKVRangesWithInterruptSignal(ctx.GetSessionVars().StmtCtx, tableID, indexID, tmpDatumRanges, nil, memTracker, interruptSignal) + tmpKeyRanges, err := distsql.IndexRangesToKVRangesWithInterruptSignal(ctx.GetSessionVars().StmtCtx, tableID, indexID, tmpDatumRanges, nil, memTracker, interruptSignal) + return tmpKeyRanges.FirstPartitionRange(), err } func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) Executor { diff --git a/executor/distsql.go b/executor/distsql.go index 0cef7e66d441e..966a4a09ce357 100644 --- a/executor/distsql.go +++ b/executor/distsql.go @@ -243,11 +243,18 @@ func (e *IndexReaderExecutor) Next(ctx context.Context, req *chunk.Chunk) error return err } +// TODO: cleanup this method. 
func (e *IndexReaderExecutor) buildKeyRanges(sc *stmtctx.StatementContext, ranges []*ranger.Range, physicalID int64) ([]kv.KeyRange, error) { + var ( + rRanges *kv.KeyRanges + err error + ) if e.index.ID == -1 { - return distsql.CommonHandleRangesToKVRanges(sc, []int64{physicalID}, ranges) + rRanges, err = distsql.CommonHandleRangesToKVRanges(sc, []int64{physicalID}, ranges) + } else { + rRanges, err = distsql.IndexRangesToKVRanges(sc, physicalID, e.index.ID, ranges, e.feedback) } - return distsql.IndexRangesToKVRanges(sc, physicalID, e.index.ID, ranges, e.feedback) + return rRanges.FirstPartitionRange(), err } // Open implements the Executor Open interface. @@ -458,9 +465,6 @@ func (e *IndexLookUpExecutor) Open(ctx context.Context) error { func (e *IndexLookUpExecutor) buildTableKeyRanges() (err error) { sc := e.ctx.GetSessionVars().StmtCtx if e.partitionTableMode { - if e.keepOrder { // this case should be prevented by the optimizer - return errors.New("invalid execution plan: cannot keep order when accessing a partition table by IndexLookUpReader") - } e.feedback.Invalidate() // feedback for partition tables is not ready e.partitionKVRanges = make([][]kv.KeyRange, 0, len(e.prunedPartitions)) for _, p := range e.prunedPartitions { @@ -472,7 +476,7 @@ func (e *IndexLookUpExecutor) buildTableKeyRanges() (err error) { if e.partitionRangeMap != nil && e.partitionRangeMap[physicalID] != nil { ranges = e.partitionRangeMap[physicalID] } - var kvRange []kv.KeyRange + var kvRange *kv.KeyRanges if e.index.ID == -1 { kvRange, err = distsql.CommonHandleRangesToKVRanges(sc, []int64{physicalID}, ranges) } else { @@ -481,15 +485,17 @@ func (e *IndexLookUpExecutor) buildTableKeyRanges() (err error) { if err != nil { return err } - e.partitionKVRanges = append(e.partitionKVRanges, kvRange) + e.partitionKVRanges = append(e.partitionKVRanges, kvRange.FirstPartitionRange()) } } else { physicalID := getPhysicalTableID(e.table) + var kvRanges *kv.KeyRanges if e.index.ID == -1 { - 
e.kvRanges, err = distsql.CommonHandleRangesToKVRanges(sc, []int64{physicalID}, e.ranges) + kvRanges, err = distsql.CommonHandleRangesToKVRanges(sc, []int64{physicalID}, e.ranges) } else { - e.kvRanges, err = distsql.IndexRangesToKVRanges(sc, physicalID, e.index.ID, e.ranges, e.feedback) + kvRanges, err = distsql.IndexRangesToKVRanges(sc, physicalID, e.index.ID, e.ranges, e.feedback) } + e.kvRanges = kvRanges.FirstPartitionRange() } return err } diff --git a/executor/index_merge_reader.go b/executor/index_merge_reader.go index b487089cfa30c..07cf36acca742 100644 --- a/executor/index_merge_reader.go +++ b/executor/index_merge_reader.go @@ -204,7 +204,7 @@ func (e *IndexMergeReaderExecutor) buildKeyRangesForTable(tbl table.Table) (rang if err != nil { return nil, err } - keyRanges := append(firstKeyRanges, secondKeyRanges...) + keyRanges := append(firstKeyRanges.FirstPartitionRange(), secondKeyRanges.FirstPartitionRange()...) ranges = append(ranges, keyRanges) continue } @@ -212,7 +212,7 @@ func (e *IndexMergeReaderExecutor) buildKeyRangesForTable(tbl table.Table) (rang if err != nil { return nil, err } - ranges = append(ranges, keyRange) + ranges = append(ranges, keyRange.FirstPartitionRange()) } return ranges, nil } diff --git a/executor/partition_table_test.go b/executor/partition_table_test.go index 50bb68a7b5235..b2ba37634a8a4 100644 --- a/executor/partition_table_test.go +++ b/executor/partition_table_test.go @@ -84,7 +84,7 @@ partition p2 values less than (10))`) // Table reader: one partition tk.MustQuery("select * from pt where c > 8").Check(testkit.Rows("9 9")) // Table reader: more than one partition - tk.MustQuery("select * from pt where c < 2 or c >= 9").Check(testkit.Rows("0 0", "9 9")) + tk.MustQuery("select * from pt where c < 2 or c >= 9").Sort().Check(testkit.Rows("0 0", "9 9")) // Index reader tk.MustQuery("select c from pt").Sort().Check(testkit.Rows("0", "2", "4", "6", "7", "9", "")) @@ -96,7 +96,7 @@ partition p2 values less than (10))`) 
tk.MustQuery("select /*+ use_index(pt, i_id) */ * from pt").Sort().Check(testkit.Rows("0 0", "2 2", "4 4", "6 6", "7 7", "9 9", " ")) tk.MustQuery("select /*+ use_index(pt, i_id) */ * from pt where id < 4 and c > 10").Check(testkit.Rows()) tk.MustQuery("select /*+ use_index(pt, i_id) */ * from pt where id < 10 and c > 8").Check(testkit.Rows("9 9")) - tk.MustQuery("select /*+ use_index(pt, i_id) */ * from pt where id < 10 and c < 2 or c >= 9").Check(testkit.Rows("0 0", "9 9")) + tk.MustQuery("select /*+ use_index(pt, i_id) */ * from pt where id < 10 and c < 2 or c >= 9").Sort().Check(testkit.Rows("0 0", "9 9")) // Index Merge tk.MustExec("set @@tidb_enable_index_merge = 1") @@ -377,14 +377,67 @@ func TestOrderByandLimit(t *testing.T) { // regular table tk.MustExec("create table tregular(a int, b int, index idx_a(a))") + // range partition table with int pk + tk.MustExec(`create table trange_intpk(a int primary key, b int) partition by range(a) ( + partition p0 values less than(300), + partition p1 values less than (500), + partition p2 values less than(1100));`) + + // hash partition table with int pk + tk.MustExec("create table thash_intpk(a int primary key, b int) partition by hash(a) partitions 4;") + + // regular table with int pk + tk.MustExec("create table tregular_intpk(a int primary key, b int)") + + // range partition table with clustered index + tk.MustExec(`create table trange_clustered(a int, b int, primary key(a, b) clustered) partition by range(a) ( + partition p0 values less than(300), + partition p1 values less than (500), + partition p2 values less than(1100));`) + + // hash partition table with clustered index + tk.MustExec("create table thash_clustered(a int, b int, primary key(a, b) clustered) partition by hash(a) partitions 4;") + + // regular table with clustered index + tk.MustExec("create table tregular_clustered(a int, b int, primary key(a, b) clustered)") + // generate some random data to be inserted vals := make([]string, 0, 2000) for i := 
0; i < 2000; i++ { vals = append(vals, fmt.Sprintf("(%v, %v)", rand.Intn(1100), rand.Intn(2000))) } + + dedupValsA := make([]string, 0, 2000) + dedupMapA := make(map[int]struct{}, 2000) + for i := 0; i < 2000; i++ { + valA := rand.Intn(1100) + if _, ok := dedupMapA[valA]; ok { + continue + } + dedupValsA = append(dedupValsA, fmt.Sprintf("(%v, %v)", valA, rand.Intn(2000))) + dedupMapA[valA] = struct{}{} + } + + dedupValsAB := make([]string, 0, 2000) + dedupMapAB := make(map[string]struct{}, 2000) + for i := 0; i < 2000; i++ { + val := fmt.Sprintf("(%v, %v)", rand.Intn(1100), rand.Intn(2000)) + if _, ok := dedupMapAB[val]; ok { + continue + } + dedupValsAB = append(dedupValsAB, val) + dedupMapAB[val] = struct{}{} + } + tk.MustExec("insert into trange values " + strings.Join(vals, ",")) tk.MustExec("insert into thash values " + strings.Join(vals, ",")) tk.MustExec("insert into tregular values " + strings.Join(vals, ",")) + tk.MustExec("insert into trange_intpk values " + strings.Join(dedupValsA, ",")) + tk.MustExec("insert into thash_intpk values " + strings.Join(dedupValsA, ",")) + tk.MustExec("insert into tregular_intpk values " + strings.Join(dedupValsA, ",")) + tk.MustExec("insert into trange_clustered values " + strings.Join(dedupValsAB, ",")) + tk.MustExec("insert into thash_clustered values " + strings.Join(dedupValsAB, ",")) + tk.MustExec("insert into tregular_clustered values " + strings.Join(dedupValsAB, ",")) // test indexLookUp for i := 0; i < 100; i++ { @@ -398,6 +451,29 @@ func TestOrderByandLimit(t *testing.T) { tk.MustQuery(queryPartition).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows()) } + // test indexLookUp with order property pushed down. 
+ for i := 0; i < 100; i++ { + // explain select * from t where a > {y} use index(idx_a) order by a limit {x}; // check if IndexLookUp is used + // select * from t where a > {y} use index(idx_a) order by a limit {x}; // it can return the correct result + x := rand.Intn(1099) + y := rand.Intn(2000) + 1 + // Since we only use order by a not order by a, b, the result is not stable when we read both a and b. + // We cut the max element so that the result can be stable. + maxEle := tk.MustQuery(fmt.Sprintf("select ifnull(max(a), 1100) from (select * from tregular use index(idx_a) where a > %v order by a limit %v) t", x, y)).Rows()[0][0] + queryRangePartitionWithLimitHint := fmt.Sprintf("select /*+ LIMIT_TO_COP() */ * from trange use index(idx_a) where a > %v and a < greatest(%v+1, %v) order by a limit %v", x, x+1, maxEle, y) + queryHashPartitionWithLimitHint := fmt.Sprintf("select /*+ LIMIT_TO_COP() */ * from thash use index(idx_a) where a > %v and a < greatest(%v+1, %v) order by a limit %v", x, x+1, maxEle, y) + queryRegular := fmt.Sprintf("select * from tregular use index(idx_a) where a > %v and a < greatest(%v+1, %v) order by a limit %v;", x, x+1, maxEle, y) + require.True(t, tk.HasPlan(queryRangePartitionWithLimitHint, "Limit")) + require.True(t, tk.HasPlan(queryRangePartitionWithLimitHint, "IndexLookUp")) + require.True(t, tk.HasPlan(queryHashPartitionWithLimitHint, "Limit")) + require.True(t, tk.HasPlan(queryHashPartitionWithLimitHint, "IndexLookUp")) + require.True(t, tk.HasPlan(queryRangePartitionWithLimitHint, "TopN")) // but not fully pushed + require.True(t, tk.HasPlan(queryHashPartitionWithLimitHint, "TopN")) + regularResult := tk.MustQuery(queryRegular).Sort().Rows() + tk.MustQuery(queryRangePartitionWithLimitHint).Sort().Check(regularResult) + tk.MustQuery(queryHashPartitionWithLimitHint).Sort().Check(regularResult) + } + // test tableReader for i := 0; i < 100; i++ { // explain select * from t where a > {y} ignore index(idx_a) order by a limit {x}; // 
check if IndexLookUp is used @@ -410,6 +486,51 @@ func TestOrderByandLimit(t *testing.T) { tk.MustQuery(queryPartition).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows()) } + // test tableReader with order property pushed down. + for i := 0; i < 100; i++ { + // explain select * from t where a > {y} ignore index(idx_a) order by a limit {x}; // check if IndexLookUp is used + // select * from t where a > {y} ignore index(idx_a) order by a limit {x}; // it can return the correct result + x := rand.Intn(1099) + y := rand.Intn(2000) + 1 + queryRangePartition := fmt.Sprintf("select /*+ LIMIT_TO_COP() */ * from trange ignore index(idx_a) where a > %v order by a, b limit %v;", x, y) + queryHashPartition := fmt.Sprintf("select /*+ LIMIT_TO_COP() */ * from thash ignore index(idx_a) where a > %v order by a, b limit %v;", x, y) + queryRegular := fmt.Sprintf("select * from tregular ignore index(idx_a) where a > %v order by a, b limit %v;", x, y) + require.True(t, tk.HasPlan(queryRangePartition, "TableReader")) // check if tableReader is used + require.True(t, tk.HasPlan(queryHashPartition, "TableReader")) + require.False(t, tk.HasPlan(queryRangePartition, "Limit")) // check if order property is not pushed + require.False(t, tk.HasPlan(queryHashPartition, "Limit")) + regularResult := tk.MustQuery(queryRegular).Sort().Rows() + tk.MustQuery(queryRangePartition).Sort().Check(regularResult) + tk.MustQuery(queryHashPartition).Sort().Check(regularResult) + + // test int pk + // To be simplified, we only read column a. 
+ queryRangePartition = fmt.Sprintf("select /*+ LIMIT_TO_COP() */ a from trange_intpk use index(primary) where a > %v order by a limit %v", x, y) + queryHashPartition = fmt.Sprintf("select /*+ LIMIT_TO_COP() */ a from thash_intpk use index(primary) where a > %v order by a limit %v", x, y) + queryRegular = fmt.Sprintf("select a from tregular_intpk where a > %v order by a limit %v", x, y) + require.True(t, tk.HasPlan(queryRangePartition, "TableReader")) + require.True(t, tk.HasPlan(queryHashPartition, "TableReader")) + require.True(t, tk.HasPlan(queryRangePartition, "Limit")) // check if order property is pushed + require.True(t, tk.HasPlan(queryHashPartition, "Limit")) + regularResult = tk.MustQuery(queryRegular).Rows() + tk.MustQuery(queryRangePartition).Check(regularResult) + tk.MustQuery(queryHashPartition).Check(regularResult) + + // test clustered index + queryRangePartition = fmt.Sprintf("select /*+ LIMIT_TO_COP() */ * from trange_clustered use index(primary) where a > %v order by a, b limit %v;", x, y) + queryHashPartition = fmt.Sprintf("select /*+ LIMIT_TO_COP() */ * from thash_clustered use index(primary) where a > %v order by a, b limit %v;", x, y) + queryRegular = fmt.Sprintf("select * from tregular_clustered where a > %v order by a, b limit %v;", x, y) + require.True(t, tk.HasPlan(queryRangePartition, "TableReader")) // check if tableReader is used + require.True(t, tk.HasPlan(queryHashPartition, "TableReader")) + require.True(t, tk.HasPlan(queryRangePartition, "Limit")) // check if order property is pushed + require.True(t, tk.HasPlan(queryHashPartition, "Limit")) + require.True(t, tk.HasPlan(queryRangePartition, "TopN")) // but not fully pushed + require.True(t, tk.HasPlan(queryHashPartition, "TopN")) + regularResult = tk.MustQuery(queryRegular).Rows() + tk.MustQuery(queryRangePartition).Check(regularResult) + tk.MustQuery(queryHashPartition).Check(regularResult) + } + // test indexReader for i := 0; i < 100; i++ { // explain select a from t where 
a > {y} use index(idx_a) order by a limit {x}; // check if IndexLookUp is used @@ -422,6 +543,24 @@ func TestOrderByandLimit(t *testing.T) { tk.MustQuery(queryPartition).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows()) } + // test indexReader with order property pushed down. + for i := 0; i < 100; i++ { + // explain select a from t where a > {y} use index(idx_a) order by a limit {x}; // check if IndexReader is used + // select a from t where a > {y} use index(idx_a) order by a limit {x}; // it can return the correct result + x := rand.Intn(1099) + y := rand.Intn(2000) + 1 + queryRangePartition := fmt.Sprintf("select /*+ LIMIT_TO_COP() */ a from trange use index(idx_a) where a > %v order by a limit %v;", x, y) + queryHashPartition := fmt.Sprintf("select /*+ LIMIT_TO_COP() */ a from thash use index(idx_a) where a > %v order by a limit %v;", x, y) + queryRegular := fmt.Sprintf("select a from tregular use index(idx_a) where a > %v order by a limit %v;", x, y) + require.True(t, tk.HasPlan(queryRangePartition, "IndexReader")) // check if indexReader is used + require.True(t, tk.HasPlan(queryHashPartition, "IndexReader")) + require.True(t, tk.HasPlan(queryRangePartition, "Limit")) // check if order property is pushed + require.True(t, tk.HasPlan(queryHashPartition, "Limit")) + regularResult := tk.MustQuery(queryRegular).Sort().Rows() + tk.MustQuery(queryRangePartition).Sort().Check(regularResult) + tk.MustQuery(queryHashPartition).Sort().Check(regularResult) + } + // test indexMerge for i := 0; i < 100; i++ { // explain select /*+ use_index_merge(t) */ * from t where a > 2 or b < 5 order by a limit {x}; // check if IndexMerge is used @@ -2834,7 +2973,7 @@ partition p1 values less than (7), partition p2 values less than (10))`) tk.MustExec("alter table p add unique idx(id)") tk.MustExec("insert into p values (1,3), (3,4), (5,6), (7,9)") - tk.MustQuery("select id from p use index (idx)").Check(testkit.Rows("1", "3", "5", "7")) + tk.MustQuery("select id from p use 
index (idx) order by id").Check(testkit.Rows("1", "3", "5", "7")) } func TestGlobalIndexDoubleRead(t *testing.T) { diff --git a/executor/table_reader.go b/executor/table_reader.go index ce5a10b125754..984212dcf7328 100644 --- a/executor/table_reader.go +++ b/executor/table_reader.go @@ -61,7 +61,7 @@ func (sr selectResultHook) SelectResult(ctx context.Context, sctx sessionctx.Con } type kvRangeBuilder interface { - buildKeyRange(ranges []*ranger.Range) ([]kv.KeyRange, error) + buildKeyRange(ranges []*ranger.Range) ([][]kv.KeyRange, error) buildKeyRangeSeparately(ranges []*ranger.Range) ([]int64, [][]kv.KeyRange, error) } @@ -205,13 +205,13 @@ func (e *TableReaderExecutor) Open(ctx context.Context) error { if err != nil { return err } - e.kvRanges = append(e.kvRanges, kvReq.KeyRanges...) + e.kvRanges = kvReq.KeyRanges.AppendSelfTo(e.kvRanges) if len(secondPartRanges) != 0 { kvReq, err = e.buildKVReq(ctx, secondPartRanges) if err != nil { return err } - e.kvRanges = append(e.kvRanges, kvReq.KeyRanges...) + e.kvRanges = kvReq.KeyRanges.AppendSelfTo(e.kvRanges) } return nil } @@ -314,10 +314,10 @@ func (e *TableReaderExecutor) buildResp(ctx context.Context, ranges []*ranger.Ra if err != nil { return nil, err } - slices.SortFunc(kvReq.KeyRanges, func(i, j kv.KeyRange) bool { + kvReq.KeyRanges.SortByFunc(func(i, j kv.KeyRange) bool { return bytes.Compare(i.StartKey, j.StartKey) < 0 }) - e.kvRanges = append(e.kvRanges, kvReq.KeyRanges...) 
+ e.kvRanges = kvReq.KeyRanges.AppendSelfTo(e.kvRanges) result, err := e.SelectResult(ctx, e.ctx, kvReq, retTypes(e), e.feedback, getPhysicalPlanIDs(e.plans), e.id) if err != nil { @@ -409,7 +409,7 @@ func (e *TableReaderExecutor) buildKVReq(ctx context.Context, ranges []*ranger.R if err != nil { return nil, err } - reqBuilder = builder.SetKeyRanges(kvRange) + reqBuilder = builder.SetPartitionKeyRanges(kvRange) } else { reqBuilder = builder.SetHandleRanges(e.ctx.GetSessionVars().StmtCtx, getPhysicalTableID(e.table), e.table.Meta() != nil && e.table.Meta().IsCommonHandle, ranges, e.feedback) } diff --git a/kv/BUILD.bazel b/kv/BUILD.bazel index 32dd9f1474179..992d99d382e42 100644 --- a/kv/BUILD.bazel +++ b/kv/BUILD.bazel @@ -48,6 +48,7 @@ go_library( "@com_github_tikv_client_go_v2//tikvrpc", "@com_github_tikv_client_go_v2//util", "@com_github_tikv_pd_client//:client", + "@org_golang_x_exp//slices", "@org_uber_go_zap//:zap", ], ) diff --git a/kv/kv.go b/kv/kv.go index 72e8111f0343d..8263746093a5c 100644 --- a/kv/kv.go +++ b/kv/kv.go @@ -15,6 +15,7 @@ package kv import ( + "bytes" "context" "crypto/tls" "time" @@ -33,6 +34,7 @@ import ( "github.com/tikv/client-go/v2/tikvrpc" "github.com/tikv/client-go/v2/util" pd "github.com/tikv/pd/client" + "golang.org/x/exp/slices" ) // UnCommitIndexKVFlag uses to indicate the index key/value is no need to commit. @@ -335,13 +337,148 @@ func (t StoreType) Name() string { return "unspecified" } +// KeyRanges wrap the ranges for partitioned table cases. +// We might send ranges from different in the one request. +type KeyRanges struct { + ranges [][]KeyRange + + isPartitioned bool +} + +// NewPartitionedKeyRanges constructs a new RequestRange for partitioned table. +func NewPartitionedKeyRanges(ranges [][]KeyRange) *KeyRanges { + return &KeyRanges{ + ranges: ranges, + isPartitioned: true, + } +} + +// NewNonParitionedKeyRanges constructs a new RequestRange for a non partitioned table. 
+func NewNonParitionedKeyRanges(ranges []KeyRange) *KeyRanges { + return &KeyRanges{ + ranges: [][]KeyRange{ranges}, + isPartitioned: false, + } +} + +// FirstPartitionRange returns the ranges of the first partition. +// We may use some func to generate ranges for both partitioned table and non partitioned table. +// This method provides a way to fallback to non-partitioned ranges. +func (rr *KeyRanges) FirstPartitionRange() []KeyRange { + if len(rr.ranges) == 0 { + return []KeyRange{} + } + return rr.ranges[0] +} + +// SetToNonPartitioned sets the status to non-partitioned. +func (rr *KeyRanges) SetToNonPartitioned() error { + if len(rr.ranges) > 1 { + return errors.Errorf("you want to change the partitioned ranges to non-partitioned ranges") + } + rr.isPartitioned = false + return nil +} + +// AppendSelfTo appends itself to another slice. +func (rr *KeyRanges) AppendSelfTo(ranges []KeyRange) []KeyRange { + for _, r := range rr.ranges { + ranges = append(ranges, r...) + } + return ranges +} + +// SortByFunc sorts each partition's ranges. +// Since the ranges are sorted in most cases, we check it first. +func (rr *KeyRanges) SortByFunc(sortFunc func(i, j KeyRange) bool) { + if !slices.IsSortedFunc(rr.ranges, func(i, j []KeyRange) bool { + // A simple short-circuit since the empty range actually won't make anything wrong. + if len(i) == 0 || len(j) == 0 { + return true + } + return sortFunc(i[0], j[0]) + }) { + slices.SortFunc(rr.ranges, func(i, j []KeyRange) bool { + if len(i) == 0 { + return len(j) != 0 + } + if len(j) == 0 { + return false + } + return sortFunc(i[0], j[0]) + }) + } + for i := range rr.ranges { + if !slices.IsSortedFunc(rr.ranges[i], sortFunc) { + slices.SortFunc(rr.ranges[i], sortFunc) + } + } +} + +// ForEachPartitionWithErr runs the func for each partition with an error check. 
+func (rr *KeyRanges) ForEachPartitionWithErr(theFunc func([]KeyRange) error) (err error) { + for i := range rr.ranges { + err = theFunc(rr.ranges[i]) + if err != nil { + return err + } + } + return nil +} + +// ForEachPartition runs the func for each partition without error check. +func (rr *KeyRanges) ForEachPartition(theFunc func([]KeyRange)) { + for i := range rr.ranges { + theFunc(rr.ranges[i]) + } +} + +// PartitionNum returns how many partitions are involved in the ranges. +func (rr *KeyRanges) PartitionNum() int { + return len(rr.ranges) +} + +// IsFullySorted checks whether the ranges are sorted inside partition and each partition is also sorted. +func (rr *KeyRanges) IsFullySorted() bool { + sortedByPartition := slices.IsSortedFunc(rr.ranges, func(i, j []KeyRange) bool { + // A simple short-circuit since the empty range actually won't make anything wrong. + if len(i) == 0 || len(j) == 0 { + return true + } + return bytes.Compare(i[0].StartKey, j[0].StartKey) < 0 + }) + if !sortedByPartition { + return false + } + for _, ranges := range rr.ranges { + if !slices.IsSortedFunc(ranges, func(i, j KeyRange) bool { + return bytes.Compare(i.StartKey, j.StartKey) < 0 + }) { + return false + } + } + return true +} + +// TotalRangeNum returns how many ranges there are. +func (rr *KeyRanges) TotalRangeNum() int { + ret := 0 + for _, r := range rr.ranges { + ret += len(r) + } + return ret +} + // Request represents a kv request. type Request struct { // Tp is the request type. - Tp int64 - StartTs uint64 - Data []byte - KeyRanges []KeyRange + Tp int64 + StartTs uint64 + Data []byte + + // KeyRanges makes sure that the request is sent first by partition then by region. + // When the table is small, it's possible that multiple partitions are in the same region. + KeyRanges *KeyRanges // For PartitionTableScan used by tiflash. 
PartitionIDAndRanges []PartitionIDAndRanges diff --git a/kv/option.go b/kv/option.go index ee5354141cd7b..a0e658f45aade 100644 --- a/kv/option.go +++ b/kv/option.go @@ -167,4 +167,6 @@ const ( InternalTxnBR = InternalTxnTools // InternalTxnTrace handles the trace statement. InternalTxnTrace = "Trace" + // InternalTxnTTL is the type of TTL usage + InternalTxnTTL = "TTL" ) diff --git a/meta/autoid/autoid_service.go b/meta/autoid/autoid_service.go index 6133dfdfc3cb2..2942f3281b769 100644 --- a/meta/autoid/autoid_service.go +++ b/meta/autoid/autoid_service.go @@ -152,7 +152,11 @@ func (sp *singlePointAlloc) resetConn() { // Close grpc.ClientConn to release resource. if grpcConn != nil { err := grpcConn.Close() - logutil.BgLogger().Info("[autoid client] AllocAutoID grpc error, reconnect", zap.Error(err)) + if err != nil { + logutil.BgLogger().Warn("[autoid client] AllocAutoID grpc error, reconnect", zap.Error(err)) + } else { + logutil.BgLogger().Info("[autoid client] AllocAutoID grpc error, reconnect") + } } } diff --git a/planner/cascades/testdata/integration_suite_in.json b/planner/cascades/testdata/integration_suite_in.json index 569cb12860ac3..5533e6c672fcb 100644 --- a/planner/cascades/testdata/integration_suite_in.json +++ b/planner/cascades/testdata/integration_suite_in.json @@ -142,7 +142,7 @@ { "name": "TestCascadePlannerHashedPartTable", "cases": [ - "select * from pt1" + "select * from pt1 order by a" ] }, { diff --git a/planner/cascades/testdata/integration_suite_out.json b/planner/cascades/testdata/integration_suite_out.json index e8d98a41ec557..262a825256e41 100644 --- a/planner/cascades/testdata/integration_suite_out.json +++ b/planner/cascades/testdata/integration_suite_out.json @@ -1198,17 +1198,18 @@ "Name": "TestCascadePlannerHashedPartTable", "Cases": [ { - "SQL": "select * from pt1", + "SQL": "select * from pt1 order by a", "Plan": [ - "TableReader_5 10000.00 root partition:all data:TableFullScan_6", - "└─TableFullScan_6 10000.00 cop[tikv] 
table:pt1 keep order:false, stats:pseudo" + "Sort_11 10000.00 root test.pt1.a", + "└─TableReader_9 10000.00 root partition:all data:TableFullScan_10", + " └─TableFullScan_10 10000.00 cop[tikv] table:pt1 keep order:false, stats:pseudo" ], "Result": [ - "4 40", "1 10", - "5 50", "2 20", - "3 30" + "3 30", + "4 40", + "5 50" ] } ] diff --git a/planner/core/find_best_task.go b/planner/core/find_best_task.go index dca8a704994d3..8596746822b23 100644 --- a/planner/core/find_best_task.go +++ b/planner/core/find_best_task.go @@ -2289,6 +2289,7 @@ func (ds *DataSource) getOriginalPhysicalIndexScan(prop *property.PhysicalProper physicalTableID: ds.physicalTableID, tblColHists: ds.TblColHists, pkIsHandleCol: ds.getPKIsHandleCol(), + constColsByCond: path.ConstCols, prop: prop, }.Init(ds.ctx, ds.blockOffset) statsTbl := ds.statisticTable diff --git a/planner/core/fragment.go b/planner/core/fragment.go index 5dfa93186826f..c6aec17f21e6d 100644 --- a/planner/core/fragment.go +++ b/planner/core/fragment.go @@ -406,7 +406,7 @@ func (e *mppTaskGenerator) constructMPPBuildTaskReqForPartitionedTable(ts *Physi return nil, nil, errors.Trace(err) } partitionIDAndRanges[i].ID = pid - partitionIDAndRanges[i].KeyRanges = kvRanges + partitionIDAndRanges[i].KeyRanges = kvRanges.FirstPartitionRange() allPartitionsIDs[i] = pid } return &kv.MPPBuildTasksRequest{PartitionIDAndRanges: partitionIDAndRanges}, allPartitionsIDs, nil @@ -417,5 +417,5 @@ func (e *mppTaskGenerator) constructMPPBuildTaskForNonPartitionTable(ts *Physica if err != nil { return nil, errors.Trace(err) } - return &kv.MPPBuildTasksRequest{KeyRanges: kvRanges}, nil + return &kv.MPPBuildTasksRequest{KeyRanges: kvRanges.FirstPartitionRange()}, nil } diff --git a/planner/core/integration_partition_test.go b/planner/core/integration_partition_test.go index 7823f18474ad1..f1b915b66d038 100644 --- a/planner/core/integration_partition_test.go +++ b/planner/core/integration_partition_test.go @@ -1458,12 +1458,12 @@ func 
TestRangeColumnsExpr(t *testing.T) { "TableReader 1.14 root partition:p5,p12 data:Selection", "└─Selection 1.14 cop[tikv] in(rce.t.a, 4, 14), in(rce.t.b, NULL, 10)", " └─TableFullScan 21.00 cop[tikv] table:t keep order:false")) - tk.MustQuery(`select * from tref where a in (4,14) and b in (null,10)`).Check(testkit.Rows( - "4 10 3", - "14 10 4")) - tk.MustQuery(`select * from t where a in (4,14) and b in (null,10)`).Check(testkit.Rows( - "4 10 3", - "14 10 4")) + tk.MustQuery(`select * from tref where a in (4,14) and b in (null,10)`).Sort().Check(testkit.Rows( + "14 10 4", + "4 10 3")) + tk.MustQuery(`select * from t where a in (4,14) and b in (null,10)`).Sort().Check(testkit.Rows( + "14 10 4", + "4 10 3")) tk.MustQuery(`explain format = 'brief' select * from t where a in (4,14) and (b in (11,10) OR b is null)`).Check(testkit.Rows( "TableReader 3.43 root partition:p1,p5,p6,p11,p12 data:Selection", "└─Selection 3.43 cop[tikv] in(rce.t.a, 4, 14), or(in(rce.t.b, 11, 10), isnull(rce.t.b))", diff --git a/planner/core/partition_pruner_test.go b/planner/core/partition_pruner_test.go index 84d51524c682a..f037d7fe887e7 100644 --- a/planner/core/partition_pruner_test.go +++ b/planner/core/partition_pruner_test.go @@ -312,15 +312,15 @@ func TestListPartitionPruner(t *testing.T) { for i, tt := range input { testdata.OnRecord(func() { output[i].SQL = tt - output[i].Result = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) + output[i].Result = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Sort().Rows()) output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery("explain format = 'brief' " + tt).Rows()) }) tk.MustQuery("explain format = 'brief' " + tt).Check(testkit.Rows(output[i].Plan...)) - result := tk.MustQuery(tt) + result := tk.MustQuery(tt).Sort() result.Check(testkit.Rows(output[i].Result...)) // If the query doesn't specified the partition, compare the result with normal table if !strings.Contains(tt, "partition(") { - result.Check(tk2.MustQuery(tt).Rows()) + 
result.Check(tk2.MustQuery(tt).Sort().Rows()) valid = true } require.True(t, valid) @@ -393,7 +393,7 @@ func TestListColumnsPartitionPruner(t *testing.T) { indexPlanTree := testdata.ConvertRowsToStrings(indexPlan.Rows()) testdata.OnRecord(func() { output[i].SQL = tt.SQL - output[i].Result = testdata.ConvertRowsToStrings(tk.MustQuery(tt.SQL).Rows()) + output[i].Result = testdata.ConvertRowsToStrings(tk.MustQuery(tt.SQL).Sort().Rows()) // Test for table without index. output[i].Plan = planTree // Test for table with index. @@ -408,14 +408,14 @@ func TestListColumnsPartitionPruner(t *testing.T) { checkPrunePartitionInfo(t, tt.SQL, tt.Pruner, indexPlanTree) // compare the result. - result := tk.MustQuery(tt.SQL) + result := tk.MustQuery(tt.SQL).Sort() idxResult := tk1.MustQuery(tt.SQL) - result.Check(idxResult.Rows()) + result.Check(idxResult.Sort().Rows()) result.Check(testkit.Rows(output[i].Result...)) // If the query doesn't specified the partition, compare the result with normal table if !strings.Contains(tt.SQL, "partition(") { - result.Check(tk2.MustQuery(tt.SQL).Rows()) + result.Check(tk2.MustQuery(tt.SQL).Sort().Rows()) valid = true } } diff --git a/planner/core/physical_plans.go b/planner/core/physical_plans.go index 5146e63e42c25..ced23204f639d 100644 --- a/planner/core/physical_plans.go +++ b/planner/core/physical_plans.go @@ -677,7 +677,12 @@ type PhysicalIndexScan struct { // tblColHists contains all columns before pruning, which are used to calculate row-size tblColHists *statistics.HistColl pkIsHandleCol *expression.Column - prop *property.PhysicalProperty + + // constColsByCond records the constant part of the index columns caused by the access conds. + // e.g. the index is (a, b, c) and there's filter a = 1 and b = 2, then the column a and b are const part. + constColsByCond []bool + + prop *property.PhysicalProperty } // Clone implements PhysicalPlan interface. 
diff --git a/planner/core/task.go b/planner/core/task.go index c446b4ea27696..ac3d10a323a7b 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -24,11 +24,13 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/charset" + "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/planner/property" "github.com/pingcap/tidb/planner/util" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/statistics" + "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/collate" @@ -959,6 +961,10 @@ func (p *PhysicalTopN) attach2Task(tasks ...task) task { } needPushDown := len(cols) > 0 if copTask, ok := t.(*copTask); ok && needPushDown && p.canPushDown(copTask.getStoreType()) && len(copTask.rootTaskConds) == 0 { + newTask, changed := p.pushTopNDownToDynamicPartition(copTask) + if changed { + return newTask + } // If all columns in topN are from index plan, we push it to index plan, otherwise we finish the index plan and // push it to table plan. var pushedDownTopN *PhysicalTopN @@ -978,6 +984,138 @@ func (p *PhysicalTopN) attach2Task(tasks ...task) task { return attachPlan2Task(p, rootTask) } +// pushTopNDownToDynamicPartition is a temp solution for partition table. It actually does the same thing as DataSource's isMatchProp. +// We need to support a more enhanced read strategy in the execution phase. So that we can achieve Limit(TiDB)->Reader(TiDB)->Limit(TiKV/TiFlash)->Scan(TiKV/TiFlash). +// Before that is done, we use this logic to provide a way to keep the order property when reading from TiKV, so that we can use the orderliness of index to speed up the query. +// Here we can change the execution plan to TopN(TiDB)->Reader(TiDB)->Limit(TiKV)->Scan(TiKV).(TiFlash is not supported). 
+func (p *PhysicalTopN) pushTopNDownToDynamicPartition(copTsk *copTask) (task, bool) { + copTsk = copTsk.copy().(*copTask) + if len(copTsk.rootTaskConds) > 0 { + return nil, false + } + colsProp, ok := GetPropByOrderByItems(p.ByItems) + if !ok { + return nil, false + } + allSameOrder, isDesc := colsProp.AllSameOrder() + if !allSameOrder { + return nil, false + } + checkIndexMatchProp := func(idxCols []*expression.Column, idxColLens []int, constColsByCond []bool, colsProp *property.PhysicalProperty) bool { + // If the number of the by-items is bigger than the index columns. We cannot push down since it must not keep order. + if len(idxCols) < len(colsProp.SortItems) { + return false + } + idxPos := 0 + for _, byItem := range colsProp.SortItems { + found := false + for ; idxPos < len(idxCols); idxPos++ { + if idxColLens[idxPos] == types.UnspecifiedLength && idxCols[idxPos].Equal(p.SCtx(), byItem.Col) { + found = true + idxPos++ + break + } + if len(constColsByCond) == 0 || idxPos >= len(constColsByCond) || !constColsByCond[idxPos] { + found = false + break + } + } + if !found { + return false + } + } + return true + } + var ( + idxScan *PhysicalIndexScan + tblScan *PhysicalTableScan + tblInfo *model.TableInfo + err error + ) + if copTsk.indexPlan != nil { + copTsk.indexPlan, err = copTsk.indexPlan.Clone() + if err != nil { + return nil, false + } + finalIdxScanPlan := copTsk.indexPlan + for len(finalIdxScanPlan.Children()) > 0 && finalIdxScanPlan.Children()[0] != nil { + finalIdxScanPlan = finalIdxScanPlan.Children()[0] + } + idxScan = finalIdxScanPlan.(*PhysicalIndexScan) + tblInfo = idxScan.Table + } + if copTsk.tablePlan != nil { + copTsk.tablePlan, err = copTsk.tablePlan.Clone() + if err != nil { + return nil, false + } + finalTblScanPlan := copTsk.tablePlan + for len(finalTblScanPlan.Children()) > 0 { + finalTblScanPlan = finalTblScanPlan.Children()[0] + } + tblScan = finalTblScanPlan.(*PhysicalTableScan) + tblInfo = tblScan.Table + } + + pi := 
tblInfo.GetPartitionInfo() + if pi == nil { + return nil, false + } + if pi.Type == model.PartitionTypeList { + return nil, false + } + + if !copTsk.indexPlanFinished { + // If indexPlan side isn't finished, there's no selection on the table side. + + propMatched := checkIndexMatchProp(idxScan.IdxCols, idxScan.IdxColLens, idxScan.constColsByCond, colsProp) + if !propMatched { + return nil, false + } + + idxScan.Desc = isDesc + childProfile := copTsk.plan().statsInfo() + newCount := p.Offset + p.Count + stats := deriveLimitStats(childProfile, float64(newCount)) + pushedLimit := PhysicalLimit{ + Count: newCount, + }.Init(p.SCtx(), stats, p.SelectBlockOffset()) + pushedLimit.SetSchema(copTsk.indexPlan.Schema()) + copTsk = attachPlan2Task(pushedLimit, copTsk).(*copTask) + } else if copTsk.indexPlan == nil { + if tblScan.HandleCols == nil { + return nil, false + } + + if tblScan.HandleCols.IsInt() { + pk := tblScan.HandleCols.GetCol(0) + if len(colsProp.SortItems) != 1 || !colsProp.SortItems[0].Col.Equal(p.SCtx(), pk) { + return nil, false + } + } else { + idxCols, idxColLens := expression.IndexInfo2PrefixCols(tblScan.Columns, tblScan.Schema().Columns, tables.FindPrimaryIndex(tblScan.Table)) + matched := checkIndexMatchProp(idxCols, idxColLens, nil, colsProp) + if !matched { + return nil, false + } + } + tblScan.Desc = isDesc + childProfile := copTsk.plan().statsInfo() + newCount := p.Offset + p.Count + stats := deriveLimitStats(childProfile, float64(newCount)) + pushedLimit := PhysicalLimit{ + Count: newCount, + }.Init(p.SCtx(), stats, p.SelectBlockOffset()) + pushedLimit.SetSchema(copTsk.tablePlan.Schema()) + copTsk = attachPlan2Task(pushedLimit, copTsk).(*copTask) + } else { + return nil, false + } + + rootTask := copTsk.convertToRootTask(p.ctx) + return attachPlan2Task(p, rootTask), true +} + func (p *PhysicalProjection) attach2Task(tasks ...task) task { t := tasks[0].copy() if cop, ok := t.(*copTask); ok { diff --git 
a/planner/core/testdata/partition_pruner_out.json b/planner/core/testdata/partition_pruner_out.json index 3206c3dc5853d..0da53482a34cc 100644 --- a/planner/core/testdata/partition_pruner_out.json +++ b/planner/core/testdata/partition_pruner_out.json @@ -847,9 +847,9 @@ { "SQL": "select * from t7 where a is null or a > 0 order by a;", "Result": [ - "", "1", - "2" + "2", + "" ], "Plan": [ "Sort 3343.33 root test_partition.t7.a", @@ -866,8 +866,8 @@ { "SQL": "select * from t1 order by id,a", "Result": [ - " 10 ", "1 1 1", + "10 10 10", "2 2 2", "3 3 3", "4 4 4", @@ -876,7 +876,7 @@ "7 7 7", "8 8 8", "9 9 9", - "10 10 10" + " 10 " ], "Plan": [ "Sort 10000.00 root test_partition.t1.id, test_partition.t1.a", @@ -1341,8 +1341,8 @@ { "SQL": "select * from t1 where a = 1 or true order by id,a", "Result": [ - " 10 ", "1 1 1", + "10 10 10", "2 2 2", "3 3 3", "4 4 4", @@ -1351,7 +1351,7 @@ "7 7 7", "8 8 8", "9 9 9", - "10 10 10" + " 10 " ], "Plan": [ "Sort 10000.00 root test_partition.t1.id, test_partition.t1.a", @@ -1973,13 +1973,13 @@ "SQL": "select * from t1 where a < 3 or b > 4", "Result": [ "1 1 1", + "10 10 10", "2 2 2", "5 5 5", "6 6 6", "7 7 7", "8 8 8", - "9 9 9", - "10 10 10" + "9 9 9" ], "Plan": [ "TableReader 5548.89 root partition:p0,p1 data:Selection", @@ -2059,11 +2059,11 @@ "SQL": "select * from t1 where (a<=1 and b<=1) or (a >=6 and b>=6)", "Result": [ "1 1 1", + "10 10 10", "6 6 6", "7 7 7", "8 8 8", - "9 9 9", - "10 10 10" + "9 9 9" ], "Plan": [ "TableReader 2092.85 root partition:p0,p1 data:Selection", @@ -2080,6 +2080,7 @@ "SQL": "select * from t1 where a <= 100 and b <= 100", "Result": [ "1 1 1", + "10 10 10", "2 2 2", "3 3 3", "4 4 4", @@ -2087,8 +2088,7 @@ "6 6 6", "7 7 7", "8 8 8", - "9 9 9", - "10 10 10" + "9 9 9" ], "Plan": [ "TableReader 1104.45 root partition:p0,p1 data:Selection", @@ -2126,10 +2126,10 @@ { "SQL": "select * from t1 left join t2 on true where (t1.a <=1 or t1.a <= 3 and (t1.b >=3 and t1.b <= 5)) and (t2.a >= 6 and t2.a <= 8) and 
t2.b>=7 and t2.id>=7 order by t1.id,t1.a", "Result": [ - "1 1 1 8 8 8", "1 1 1 7 7 7", - "3 3 3 8 8 8", - "3 3 3 7 7 7" + "1 1 1 8 8 8", + "3 3 3 7 7 7", + "3 3 3 8 8 8" ], "Plan": [ "Sort 93855.70 root test_partition.t1.id, test_partition.t1.a", @@ -2326,8 +2326,8 @@ { "SQL": "select * from t1 where a = 3 or true order by id,a", "Result": [ - " 10 ", "1 1 1", + "10 10 10", "2 2 2", "3 3 3", "4 4 4", @@ -2336,7 +2336,7 @@ "7 7 7", "8 8 8", "9 9 9", - "10 10 10" + " 10 " ], "Plan": [ "Sort 10000.00 root test_partition.t1.id, test_partition.t1.a", @@ -2463,6 +2463,7 @@ "SQL": "select * from t1 where (a >= 1 and a <= 6) or (a>=3 and b >=3)", "Result": [ "1 1 1", + "10 10 10", "2 2 2", "3 3 3", "4 4 4", @@ -2470,8 +2471,7 @@ "6 6 6", "7 7 7", "8 8 8", - "9 9 9", - "10 10 10" + "9 9 9" ], "Plan": [ "TableReader 1333.33 root partition:p0,p1 data:Selection", diff --git a/store/copr/batch_coprocessor.go b/store/copr/batch_coprocessor.go index bfd3bbcc94fdd..5f6e435028e3b 100644 --- a/store/copr/batch_coprocessor.go +++ b/store/copr/batch_coprocessor.go @@ -695,7 +695,8 @@ func (c *CopClient) sendBatch(ctx context.Context, req *kv.Request, vars *tikv.V } tasks, err = buildBatchCopTasksForPartitionedTable(bo, c.store.kvStore, keyRanges, req.StoreType, nil, 0, false, 0, partitionIDs) } else { - ranges := NewKeyRanges(req.KeyRanges) + // TODO: merge the if branch. 
+ ranges := NewKeyRanges(req.KeyRanges.FirstPartitionRange()) tasks, err = buildBatchCopTasksForNonPartitionedTable(bo, c.store.kvStore, ranges, req.StoreType, nil, 0, false, 0) } diff --git a/store/copr/copr_test/coprocessor_test.go b/store/copr/copr_test/coprocessor_test.go index f92db7ba7c334..208f2e2bd2190 100644 --- a/store/copr/copr_test/coprocessor_test.go +++ b/store/copr/copr_test/coprocessor_test.go @@ -43,7 +43,7 @@ func TestBuildCopIteratorWithRowCountHint(t *testing.T) { req := &kv.Request{ Tp: kv.ReqTypeDAG, - KeyRanges: copr.BuildKeyRanges("a", "c", "d", "e", "h", "x", "y", "z"), + KeyRanges: kv.NewNonParitionedKeyRanges(copr.BuildKeyRanges("a", "c", "d", "e", "h", "x", "y", "z")), FixedRowCountHint: []int{1, 1, 3, copr.CopSmallTaskRow}, Concurrency: 15, } @@ -57,7 +57,7 @@ func TestBuildCopIteratorWithRowCountHint(t *testing.T) { req = &kv.Request{ Tp: kv.ReqTypeDAG, - KeyRanges: copr.BuildKeyRanges("a", "c", "d", "e", "h", "x", "y", "z"), + KeyRanges: kv.NewNonParitionedKeyRanges(copr.BuildKeyRanges("a", "c", "d", "e", "h", "x", "y", "z")), FixedRowCountHint: []int{1, 1, 3, 3}, Concurrency: 15, } @@ -72,7 +72,7 @@ func TestBuildCopIteratorWithRowCountHint(t *testing.T) { // cross-region long range req = &kv.Request{ Tp: kv.ReqTypeDAG, - KeyRanges: copr.BuildKeyRanges("a", "z"), + KeyRanges: kv.NewNonParitionedKeyRanges(copr.BuildKeyRanges("a", "z")), FixedRowCountHint: []int{10}, Concurrency: 15, } @@ -86,7 +86,7 @@ func TestBuildCopIteratorWithRowCountHint(t *testing.T) { req = &kv.Request{ Tp: kv.ReqTypeDAG, - KeyRanges: copr.BuildKeyRanges("a", "z"), + KeyRanges: kv.NewNonParitionedKeyRanges(copr.BuildKeyRanges("a", "z")), FixedRowCountHint: []int{copr.CopSmallTaskRow + 1}, Concurrency: 15, } diff --git a/store/copr/coprocessor.go b/store/copr/coprocessor.go index 982e981f24c79..fdff94a09717d 100644 --- a/store/copr/coprocessor.go +++ b/store/copr/coprocessor.go @@ -15,7 +15,6 @@ package copr import ( - "bytes" "context" "fmt" "math" @@ -53,7 
+52,6 @@ import ( "github.com/tikv/client-go/v2/txnkv/txnsnapshot" "github.com/tikv/client-go/v2/util" "go.uber.org/zap" - "golang.org/x/exp/slices" ) var coprCacheCounterEvict = tidbmetrics.DistSQLCoprCacheCounter.WithLabelValues("evict") @@ -121,10 +119,7 @@ func (c *CopClient) BuildCopIterator(ctx context.Context, req *kv.Request, vars } failpoint.Inject("checkKeyRangeSortedForPaging", func(_ failpoint.Value) { if req.Paging.Enable { - isSorted := slices.IsSortedFunc(req.KeyRanges, func(i, j kv.KeyRange) bool { - return bytes.Compare(i.StartKey, j.StartKey) < 0 - }) - if !isSorted { + if !req.KeyRanges.IsFullySorted() { logutil.BgLogger().Fatal("distsql request key range not sorted!") } } @@ -138,8 +133,27 @@ func (c *CopClient) BuildCopIterator(ctx context.Context, req *kv.Request, vars }) bo := backoff.NewBackofferWithVars(ctx, copBuildTaskMaxBackoff, vars) - ranges := NewKeyRanges(req.KeyRanges) - tasks, err := buildCopTasks(bo, c.store.GetRegionCache(), ranges, req, eventCb) + var ( + tasks []*copTask + err error + ) + buildTaskFunc := func(ranges []kv.KeyRange) error { + keyRanges := NewKeyRanges(ranges) + tasksFromRanges, err := buildCopTasks(bo, c.store.GetRegionCache(), keyRanges, req, eventCb) + if err != nil { + return err + } + if len(tasks) == 0 { + tasks = tasksFromRanges + return nil + } + tasks = append(tasks, tasksFromRanges...) + return nil + } + // Here we build the task by partition, not directly by region. + // This is because it's possible that TiDB merge multiple small partition into one region which break some assumption. + // Keep it split by partition would be more safe. 
+ err = req.KeyRanges.ForEachPartitionWithErr(buildTaskFunc) reqType := "null" if req.ClosestReplicaReadAdjuster != nil { reqType = "miss" diff --git a/tablecodec/tablecodec.go b/tablecodec/tablecodec.go index e45576b9d0674..c2d98f5a2b17e 100644 --- a/tablecodec/tablecodec.go +++ b/tablecodec/tablecodec.go @@ -1627,3 +1627,30 @@ func IndexKVIsUnique(value []byte) bool { segs := SplitIndexValue(value) return segs.IntHandle != nil || segs.CommonHandle != nil } + +// VerifyTableIDForRanges verifies that all given ranges are valid to decode the table id. +func VerifyTableIDForRanges(keyRanges *kv.KeyRanges) ([]int64, error) { + tids := make([]int64, 0, keyRanges.PartitionNum()) + collectFunc := func(ranges []kv.KeyRange) error { + if len(ranges) == 0 { + return nil + } + tid := DecodeTableID(ranges[0].StartKey) + if tid <= 0 { + return errors.New("Incorrect keyRange is constrcuted") + } + tids = append(tids, tid) + for i := 1; i < len(ranges); i++ { + tmpTID := DecodeTableID(ranges[i].StartKey) + if tmpTID <= 0 { + return errors.New("Incorrect keyRange is constrcuted") + } + if tid != tmpTID { + return errors.Errorf("Using multi partition's ranges as single table's") + } + } + return nil + } + err := keyRanges.ForEachPartitionWithErr(collectFunc) + return tids, err +} diff --git a/testkit/result.go b/testkit/result.go index 0f7ad0ce53cbc..210d32d4c57b9 100644 --- a/testkit/result.go +++ b/testkit/result.go @@ -49,6 +49,11 @@ func (res *Result) Check(expected [][]interface{}) { res.require.Equal(needBuff.String(), resBuff.String(), res.comment) } +// AddComment adds the extra comment for the Result's output. +func (res *Result) AddComment(c string) { + res.comment += "\n" + c +} + // CheckWithFunc asserts the result match the expected results in the way `f` specifies. 
func (res *Result) CheckWithFunc(expected [][]interface{}, f func([]string, []interface{}) bool) { res.require.Equal(len(res.rows), len(expected), res.comment+"\nResult length mismatch") diff --git a/ttl/BUILD.bazel b/ttl/BUILD.bazel index e6a76c69d8df5..e5ec05168b29b 100644 --- a/ttl/BUILD.bazel +++ b/ttl/BUILD.bazel @@ -3,17 +3,25 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "ttl", srcs = [ + "session.go", "sql.go", "table.go", ], importpath = "github.com/pingcap/tidb/ttl", visibility = ["//visibility:public"], deps = [ + "//infoschema", + "//kv", "//parser/ast", "//parser/format", "//parser/model", "//parser/mysql", + "//parser/terror", + "//sessionctx", + "//sessiontxn", + "//table/tables", "//types", + "//util/chunk", "//util/sqlexec", "@com_github_pingcap_errors//:errors", "@com_github_pkg_errors//:errors", @@ -24,11 +32,13 @@ go_test( name = "ttl_test", srcs = [ "main_test.go", + "session_test.go", "sql_test.go", + "table_test.go", ], + embed = [":ttl"], flaky = True, deps = [ - ":ttl", "//kv", "//parser", "//parser/ast", @@ -38,6 +48,7 @@ go_test( "//testkit/testsetup", "//types", "//util/sqlexec", + "@com_github_pingcap_errors//:errors", "@com_github_stretchr_testify//require", "@org_uber_go_goleak//:goleak", ], diff --git a/ttl/session.go b/ttl/session.go new file mode 100644 index 0000000000000..b3321e0d53c06 --- /dev/null +++ b/ttl/session.go @@ -0,0 +1,123 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package ttl + +import ( + "context" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/parser/terror" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessiontxn" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/sqlexec" +) + +// Session is used to execute queries for TTL case +type Session interface { + sessionctx.Context + // SessionInfoSchema returns information schema of current session + SessionInfoSchema() infoschema.InfoSchema + // ExecuteSQL executes the sql + ExecuteSQL(ctx context.Context, sql string, args ...interface{}) ([]chunk.Row, error) + // RunInTxn executes the specified function in a txn + RunInTxn(ctx context.Context, fn func() error) (err error) + // Close closes the session + Close() +} + +type session struct { + sessionctx.Context + sqlExec sqlexec.SQLExecutor + closeFn func() +} + +// NewSession creates a new Session +func NewSession(sctx sessionctx.Context, sqlExec sqlexec.SQLExecutor, closeFn func()) Session { + return &session{ + Context: sctx, + sqlExec: sqlExec, + closeFn: closeFn, + } +} + +// SessionInfoSchema returns information schema of current session +func (s *session) SessionInfoSchema() infoschema.InfoSchema { + if s.Context == nil { + return nil + } + return sessiontxn.GetTxnManager(s.Context).GetTxnInfoSchema() +} + +// ExecuteSQL executes the sql +func (s *session) ExecuteSQL(ctx context.Context, sql string, args ...interface{}) ([]chunk.Row, error) { + if s.sqlExec == nil { + return nil, errors.New("session is closed") + } + + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnTTL) + rs, err := s.sqlExec.ExecuteInternal(ctx, sql, args...) 
+ if err != nil { + return nil, err + } + + if rs == nil { + return nil, nil + } + + defer func() { + terror.Log(rs.Close()) + }() + + return sqlexec.DrainRecordSet(ctx, rs, 8) +} + +// RunInTxn runs fn between BEGIN and COMMIT. If fn or COMMIT fails, it issues ROLLBACK and returns that original error; a ROLLBACK failure is only logged and never replaces it. +func (s *session) RunInTxn(ctx context.Context, fn func() error) (err error) { + if _, err = s.ExecuteSQL(ctx, "BEGIN"); err != nil { + return err + } + + success := false + defer func() { + if !success { + _, rollbackErr := s.ExecuteSQL(ctx, "ROLLBACK") + terror.Log(rollbackErr) + } + }() + + if err = fn(); err != nil { + return err + } + + if _, err = s.ExecuteSQL(ctx, "COMMIT"); err != nil { + return err + } + + success = true + return err +} + +// Close closes the session +func (s *session) Close() { + if s.closeFn != nil { + s.closeFn() + s.Context = nil + s.sqlExec = nil + s.closeFn = nil + } +} diff --git a/ttl/session_test.go b/ttl/session_test.go new file mode 100644 index 0000000000000..90d47ed313e73 --- /dev/null +++ b/ttl/session_test.go @@ -0,0 +1,52 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ttl + +import ( + "context" + "testing" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/testkit" + "github.com/stretchr/testify/require" +) + +func TestSessionRunInTxn(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t(id int primary key, v int)") + se := NewSession(tk.Session(), tk.Session(), nil) + tk2 := testkit.NewTestKit(t, store) + tk2.MustExec("use test") + + require.NoError(t, se.RunInTxn(context.TODO(), func() error { + tk.MustExec("insert into t values (1, 10)") + return nil + })) + tk2.MustQuery("select * from t order by id asc").Check(testkit.Rows("1 10")) + + require.EqualError(t, se.RunInTxn(context.TODO(), func() error { + tk.MustExec("insert into t values (2, 20)") + return errors.New("err") + }), "err") + tk2.MustQuery("select * from t order by id asc").Check(testkit.Rows("1 10")) + + require.NoError(t, se.RunInTxn(context.TODO(), func() error { + tk.MustExec("insert into t values (3, 30)") + return nil + })) + tk2.MustQuery("select * from t order by id asc").Check(testkit.Rows("1 10", "3 30")) +} diff --git a/ttl/sql.go b/ttl/sql.go index 9cdf762846d4d..3d100fd62eee7 100644 --- a/ttl/sql.go +++ b/ttl/sql.go @@ -31,6 +31,8 @@ import ( "github.com/pkg/errors" ) +const dateTimeFormat = "2006-01-02 15:04:05.999999" + func writeHex(in io.Writer, d types.Datum) error { _, err := fmt.Fprintf(in, "x'%s'", hex.EncodeToString(d.GetBytes())) return err @@ -179,7 +181,7 @@ func (b *SQLBuilder) WriteExpireCondition(expire time.Time) error { b.writeColNames([]*model.ColumnInfo{b.tbl.TimeColumn}, false) b.restoreCtx.WritePlain(" < ") b.restoreCtx.WritePlain("'") - b.restoreCtx.WritePlain(expire.Format("2006-01-02 15:04:05.999999")) + b.restoreCtx.WritePlain(expire.Format(dateTimeFormat)) b.restoreCtx.WritePlain("'") b.hasWriteExpireCond = true return nil diff --git a/ttl/table.go b/ttl/table.go index b9c59a34e5c17..4885da0e137b4 100644 --- a/ttl/table.go
+++ b/ttl/table.go @@ -15,23 +15,115 @@ package ttl import ( + "context" + "fmt" + "time" + "github.com/pingcap/errors" + "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" ) +func getTableKeyColumns(tbl *model.TableInfo) ([]*model.ColumnInfo, []*types.FieldType, error) { + if tbl.PKIsHandle { + for i, col := range tbl.Columns { + if mysql.HasPriKeyFlag(col.GetFlag()) { + return []*model.ColumnInfo{tbl.Columns[i]}, []*types.FieldType{&tbl.Columns[i].FieldType}, nil + } + } + return nil, nil, errors.Errorf("Cannot find primary key for table: %s", tbl.Name) + } + + if tbl.IsCommonHandle { + idxInfo := tables.FindPrimaryIndex(tbl) + columns := make([]*model.ColumnInfo, len(idxInfo.Columns)) + fieldTypes := make([]*types.FieldType, len(idxInfo.Columns)) + for i, idxCol := range idxInfo.Columns { + columns[i] = tbl.Columns[idxCol.Offset] + fieldTypes[i] = &tbl.Columns[idxCol.Offset].FieldType + } + return columns, fieldTypes, nil + } + + extraHandleColInfo := model.NewExtraHandleColInfo() + return []*model.ColumnInfo{extraHandleColInfo}, []*types.FieldType{&extraHandleColInfo.FieldType}, nil +} + // PhysicalTable is used to provide some information for a physical table in TTL job type PhysicalTable struct { + // Schema is the database name of the table Schema model.CIStr *model.TableInfo + // Partition is the partition name + Partition model.CIStr // PartitionDef is the partition definition PartitionDef *model.PartitionDefinition // KeyColumns is the cluster index key columns for the table KeyColumns []*model.ColumnInfo + // KeyColumnTypes is the types of the key columns + KeyColumnTypes []*types.FieldType // TimeColum is the time column used for TTL TimeColumn *model.ColumnInfo } +// NewPhysicalTable create a new PhysicalTable +func NewPhysicalTable(schema model.CIStr, tbl *model.TableInfo, 
partition model.CIStr) (*PhysicalTable, error) { + if tbl.State != model.StatePublic { + return nil, errors.Errorf("table '%s.%s' is not a public table", schema, tbl.Name) + } + + ttlInfo := tbl.TTLInfo + if ttlInfo == nil { + return nil, errors.Errorf("table '%s.%s' is not a ttl table", schema, tbl.Name) + } + + timeColumn := tbl.FindPublicColumnByName(ttlInfo.ColumnName.L) + if timeColumn == nil { + return nil, errors.Errorf("time column '%s' is not public in ttl table '%s.%s'", ttlInfo.ColumnName, schema, tbl.Name) + } + + keyColumns, keyColumTypes, err := getTableKeyColumns(tbl) + if err != nil { + return nil, err + } + + var partitionDef *model.PartitionDefinition + if tbl.Partition == nil { + if partition.L != "" { + return nil, errors.Errorf("table '%s.%s' is not a partitioned table", schema, tbl.Name) + } + } else { + if partition.L == "" { + return nil, errors.Errorf("partition name is required, table '%s.%s' is a partitioned table", schema, tbl.Name) + } + + for i := range tbl.Partition.Definitions { + def := &tbl.Partition.Definitions[i] + if def.Name.L == partition.L { + partitionDef = def + } + } + + if partitionDef == nil { + return nil, errors.Errorf("partition '%s' is not found in ttl table '%s.%s'", partition.O, schema, tbl.Name) + } + } + + return &PhysicalTable{ + Schema: schema, + TableInfo: tbl, + Partition: partition, + PartitionDef: partitionDef, + KeyColumns: keyColumns, + KeyColumnTypes: keyColumTypes, + TimeColumn: timeColumn, + }, nil +} + // ValidateKey validates a key func (t *PhysicalTable) ValidateKey(key []types.Datum) error { if len(t.KeyColumns) != len(key) { @@ -39,3 +131,30 @@ func (t *PhysicalTable) ValidateKey(key []types.Datum) error { } return nil } + +// EvalExpireTime asks the session to evaluate `now - TTL interval` in SQL and returns the result as a time in the session's time zone; it returns an error if the query fails or yields no row +func (t *PhysicalTable) EvalExpireTime(ctx context.Context, se Session, now time.Time) (expire time.Time, err error) { + tz := se.GetSessionVars().TimeZone + + expireExpr := t.TTLInfo.IntervalExprStr + unit := 
ast.TimeUnitType(t.TTLInfo.IntervalTimeUnit) + + var rows []chunk.Row + rows, err = se.ExecuteSQL( + ctx, + // FROM_UNIXTIME does not support negative value, so we use `FROM_UNIXTIME(0) + INTERVAL ` to present current time + fmt.Sprintf("SELECT FROM_UNIXTIME(0) + INTERVAL %d SECOND - INTERVAL %s %s", now.Unix(), expireExpr, unit.String()), + ) + + if err != nil { + return + } + + // A Session may return no rows (session.ExecuteSQL yields nil rows for a nil record set), so check before indexing rows[0] + if len(rows) == 0 { + return expire, errors.Errorf("cannot evaluate expire time for ttl table '%s.%s': no row returned", t.Schema, t.Name) + } + + tm := rows[0].GetTime(0) + return tm.CoreTime().GoTime(tz) +} diff --git a/ttl/table_test.go b/ttl/table_test.go new file mode 100644 index 0000000000000..f77556c98dc09 --- /dev/null +++ b/ttl/table_test.go @@ -0,0 +1,213 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +package ttl_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/testkit" + "github.com/pingcap/tidb/ttl" + "github.com/stretchr/testify/require" +) + +func TestNewTTLTable(t *testing.T) { + cases := []struct { + db string + tbl string + def string + timeCol string + keyCols []string + }{ + { + db: "test", + tbl: "t1", + def: "(a int)", + }, + { + db: "test", + tbl: "ttl1", + def: "(a int, t datetime) ttl = `t` + interval 2 hour", + timeCol: "t", + keyCols: []string{"_tidb_rowid"}, + }, + { + db: "test", + tbl: "ttl2", + def: "(id int primary key, t datetime) ttl = `t` + interval 3 hour", + timeCol: "t", + keyCols: []string{"id"}, + }, + { + db: "test", + tbl: "ttl3", + def: "(a int, b varchar(32), c binary(32), t datetime, primary key (a, b, c)) ttl = `t` + interval 1 month", + timeCol: "t", + keyCols: []string{"a", "b", "c"}, + }, + { + db: "test", + tbl: "ttl4", + def: "(id int primary key, t datetime) " + + "ttl = `t` + interval 1 day " + + "PARTITION BY RANGE (id) (" + + " PARTITION p0 VALUES LESS THAN (10)," + + " PARTITION p1 VALUES LESS THAN (100)," + + " PARTITION p2 VALUES LESS THAN (1000)," + + " PARTITION p3 VALUES LESS THAN MAXVALUE)", + timeCol: "t", + keyCols: []string{"id"}, + }, + { + db: "test", + tbl: "ttl5", + def: "(id int primary key nonclustered, t datetime) ttl = `t` + interval 3 hour", + timeCol: "t", + keyCols: []string{"_tidb_rowid"}, + }, + } + + store, do := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + + for _, c := range cases { + tk.MustExec("use " + c.db) + tk.MustExec("create table " + c.tbl + c.def) + } + + for _, c := range cases { + is := do.InfoSchema() + tbl, err := is.TableByName(model.NewCIStr(c.db), model.NewCIStr(c.tbl)) + require.NoError(t, err) + tblInfo := tbl.Meta() + var physicalTbls []*ttl.PhysicalTable + if tblInfo.Partition == nil { + ttlTbl, err := ttl.NewPhysicalTable(model.NewCIStr(c.db), tblInfo, 
model.NewCIStr("")) + if c.timeCol == "" { + require.Error(t, err) + continue + } + require.NoError(t, err) + physicalTbls = append(physicalTbls, ttlTbl) + } else { + for _, partition := range tblInfo.Partition.Definitions { + ttlTbl, err := ttl.NewPhysicalTable(model.NewCIStr(c.db), tblInfo, model.NewCIStr(partition.Name.O)) + if c.timeCol == "" { + require.Error(t, err) + continue + } + require.NoError(t, err) + physicalTbls = append(physicalTbls, ttlTbl) + } + if c.timeCol == "" { + continue + } + } + + for i, ttlTbl := range physicalTbls { + require.Equal(t, c.db, ttlTbl.Schema.O) + require.Same(t, tblInfo, ttlTbl.TableInfo) + timeColumn := tblInfo.FindPublicColumnByName(c.timeCol) + require.NotNil(t, timeColumn) + require.Same(t, timeColumn, ttlTbl.TimeColumn) + + if tblInfo.Partition == nil { + require.Equal(t, "", ttlTbl.Partition.L) + require.Nil(t, ttlTbl.PartitionDef) + } else { + def := tblInfo.Partition.Definitions[i] + require.Equal(t, def.Name.L, ttlTbl.Partition.L) + require.Equal(t, def, *(ttlTbl.PartitionDef)) + } + + require.Equal(t, len(c.keyCols), len(ttlTbl.KeyColumns)) + require.Equal(t, len(c.keyCols), len(ttlTbl.KeyColumnTypes)) + + for j, keyCol := range c.keyCols { + msg := fmt.Sprintf("%s, col: %s", c.tbl, keyCol) + var col *model.ColumnInfo + if keyCol == model.ExtraHandleName.L { + col = model.NewExtraHandleColInfo() + } else { + col = tblInfo.FindPublicColumnByName(keyCol) + } + colJ := ttlTbl.KeyColumns[j] + colFieldJ := ttlTbl.KeyColumnTypes[j] + + require.NotNil(t, col, msg) + require.Equal(t, col.ID, colJ.ID, msg) + require.Equal(t, col.Name.L, colJ.Name.L, msg) + require.Equal(t, col.FieldType, colJ.FieldType, msg) + require.Equal(t, col.FieldType, *colFieldJ, msg) + } + } + } +} + +func TestEvalTTLExpireTime(t *testing.T) { + store, do := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("create table test.t(a int, t datetime) ttl = `t` + interval 1 day") + tk.MustExec("create table test.t2(a 
int, t datetime) ttl = `t` + interval 3 month") + + tb, err := do.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t")) + require.NoError(t, err) + tblInfo := tb.Meta() + ttlTbl, err := ttl.NewPhysicalTable(model.NewCIStr("test"), tblInfo, model.NewCIStr("")) + require.NoError(t, err) + + tb2, err := do.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t2")) + require.NoError(t, err) + tblInfo2 := tb2.Meta() + ttlTbl2, err := ttl.NewPhysicalTable(model.NewCIStr("test"), tblInfo2, model.NewCIStr("")) + require.NoError(t, err) + + se := ttl.NewSession(tk.Session(), tk.Session(), nil) + + now := time.UnixMilli(0) + tz1, err := time.LoadLocation("Asia/Shanghai") + require.NoError(t, err) + tz2, err := time.LoadLocation("Europe/Berlin") + require.NoError(t, err) + + se.GetSessionVars().TimeZone = tz1 + tm, err := ttlTbl.EvalExpireTime(context.TODO(), se, now) + require.NoError(t, err) + require.Equal(t, now.Add(-time.Hour*24).Unix(), tm.Unix()) + require.Equal(t, "1969-12-31 08:00:00", tm.Format("2006-01-02 15:04:05")) + require.Equal(t, tz1.String(), tm.Location().String()) + + se.GetSessionVars().TimeZone = tz2 + tm, err = ttlTbl.EvalExpireTime(context.TODO(), se, now) + require.NoError(t, err) + require.Equal(t, now.Add(-time.Hour*24).Unix(), tm.Unix()) + require.Equal(t, "1969-12-31 01:00:00", tm.Format("2006-01-02 15:04:05")) + require.Equal(t, tz2.String(), tm.Location().String()) + + se.GetSessionVars().TimeZone = tz1 + tm, err = ttlTbl2.EvalExpireTime(context.TODO(), se, now) + require.NoError(t, err) + require.Equal(t, "1969-10-01 08:00:00", tm.Format("2006-01-02 15:04:05")) + require.Equal(t, tz1.String(), tm.Location().String()) + + se.GetSessionVars().TimeZone = tz2 + tm, err = ttlTbl2.EvalExpireTime(context.TODO(), se, now) + require.NoError(t, err) + require.Equal(t, "1969-10-01 01:00:00", tm.Format("2006-01-02 15:04:05")) + require.Equal(t, tz2.String(), tm.Location().String()) +}