diff --git a/WORKSPACE b/WORKSPACE index 04b8cb3e31341..435a339dc8269 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -35,7 +35,7 @@ go_download_sdk( "https://mirrors.aliyun.com/golang/{}", "https://dl.google.com/go/{}", ], - version = "1.19.5", + version = "1.19.8", ) go_register_toolchains( diff --git a/br/pkg/conn/conn.go b/br/pkg/conn/conn.go index 5adbe0a33ab1c..fff775bf1c1d7 100644 --- a/br/pkg/conn/conn.go +++ b/br/pkg/conn/conn.go @@ -281,7 +281,8 @@ func (mgr *Mgr) GetTS(ctx context.Context) (uint64, error) { } // GetMergeRegionSizeAndCount returns the tikv config `coprocessor.region-split-size` and `coprocessor.region-split-key`. -func (mgr *Mgr) GetMergeRegionSizeAndCount(ctx context.Context, client *http.Client) (uint64, uint64, error) { +// It returns the default config when it fails to fetch the config from TiKV. +func (mgr *Mgr) GetMergeRegionSizeAndCount(ctx context.Context, client *http.Client) (uint64, uint64) { regionSplitSize := DefaultMergeRegionSizeBytes regionSplitKeys := DefaultMergeRegionKeyCount type coprocessor struct { @@ -310,9 +311,10 @@ func (mgr *Mgr) GetMergeRegionSizeAndCount(ctx context.Context, client *http.Cli return nil }) if err != nil { - return 0, 0, errors.Trace(err) + log.Warn("failed to get config from TiKV, using the default config", logutil.ShortError(err)) + return DefaultMergeRegionSizeBytes, DefaultMergeRegionKeyCount } - return regionSplitSize, regionSplitKeys, nil + return regionSplitSize, regionSplitKeys } // GetConfigFromTiKV get configs from all alive tikv stores. diff --git a/br/pkg/conn/conn_test.go b/br/pkg/conn/conn_test.go index 01ce8bc08203e..fc822fac123d9 100644 --- a/br/pkg/conn/conn_test.go +++ b/br/pkg/conn/conn_test.go @@ -292,6 +292,38 @@ func TestGetMergeRegionSizeAndCount(t *testing.T) { regionSplitSize: DefaultMergeRegionSizeBytes, regionSplitKeys: DefaultMergeRegionKeyCount, }, + { + stores: []*metapb.Store{ + { + Id: 1, + State: metapb.StoreState_Up, + Labels: []*metapb.StoreLabel{ + { + Key: "engine", + Value: "tiflash", + }, + }, + }, + { + Id: 2, + State: metapb.StoreState_Up, + Labels: []*metapb.StoreLabel{ + { + Key: "engine", + Value: "tikv", + }, + }, + }, + }, + content: []string{ + "", + // Assume this TiKV store has failed for some reason.
+ "", + }, + // no tikv detected in this case + regionSplitSize: DefaultMergeRegionSizeBytes, + regionSplitKeys: DefaultMergeRegionKeyCount, + }, { stores: []*metapb.Store{ { @@ -388,8 +420,7 @@ func TestGetMergeRegionSizeAndCount(t *testing.T) { httpCli := mockServer.Client() mgr := &Mgr{PdController: &pdutil.PdController{}} mgr.PdController.SetPDClient(pdCli) - rs, rk, err := mgr.GetMergeRegionSizeAndCount(ctx, httpCli) - require.NoError(t, err) + rs, rk := mgr.GetMergeRegionSizeAndCount(ctx, httpCli) require.Equal(t, ca.regionSplitSize, rs) require.Equal(t, ca.regionSplitKeys, rk) mockServer.Close() diff --git a/br/pkg/gluetidb/BUILD.bazel b/br/pkg/gluetidb/BUILD.bazel index 5340729c1d548..eddbd41ee46d4 100644 --- a/br/pkg/gluetidb/BUILD.bazel +++ b/br/pkg/gluetidb/BUILD.bazel @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "gluetidb", @@ -25,3 +25,21 @@ go_library( "@org_uber_go_zap//:zap", ], ) + +go_test( + name = "gluetidb_test", + timeout = "short", + srcs = ["glue_test.go"], + embed = [":gluetidb"], + flaky = True, + deps = [ + "//ddl", + "//kv", + "//meta", + "//parser/model", + "//sessionctx", + "//testkit", + "@com_github_pingcap_failpoint//:failpoint", + "@com_github_stretchr_testify//require", + ], +) diff --git a/br/pkg/gluetidb/glue.go b/br/pkg/gluetidb/glue.go index a6a53a91d500b..15c2e5c89a1f3 100644 --- a/br/pkg/gluetidb/glue.go +++ b/br/pkg/gluetidb/glue.go @@ -207,9 +207,34 @@ func (gs *tidbSession) CreatePlacementPolicy(ctx context.Context, policy *model. return d.CreatePlacementPolicyWithInfo(gs.se, policy, ddl.OnExistIgnore) } +// SplitBatchCreateTable provides a way to split a batch into smaller batches when the batch size is larger than 6 MB. +// The raft entry has a size limit of 6 MB, so a batch of CreateTables may hit this limitation. +// TODO: shall the query string be set for each split batch create? It looks like it does not matter if we set it once for all. +func (gs *tidbSession) SplitBatchCreateTable(schema model.CIStr, infos []*model.TableInfo, cs ...ddl.CreateTableWithInfoConfigurier) error { + var err error + d := domain.GetDomain(gs.se).DDL() + + if err = d.BatchCreateTableWithInfo(gs.se, schema, infos, append(cs, ddl.OnExistIgnore)...); kv.ErrEntryTooLarge.Equal(err) { + log.Info("entry too large, split batch create table", zap.Int("num table", len(infos))) + if len(infos) == 1 { + return err + } + mid := len(infos) / 2 + err = gs.SplitBatchCreateTable(schema, infos[:mid], cs...) + if err != nil { + return err + } + err = gs.SplitBatchCreateTable(schema, infos[mid:], cs...) + if err != nil { + return err + } + return nil + } + return err +} + // CreateTables implements glue.BatchCreateTableSession. func (gs *tidbSession) CreateTables(ctx context.Context, tables map[string][]*model.TableInfo, cs ...ddl.CreateTableWithInfoConfigurier) error { - d := domain.GetDomain(gs.se).DDL() var dbName model.CIStr // Disable foreign key check when batch create tables. @@ -237,8 +262,8 @@ func (gs *tidbSession) CreateTables(ctx context.Context, tables map[string][]*mo cloneTables = append(cloneTables, table) } gs.se.SetValue(sessionctx.QueryString, queryBuilder.String()) - err := d.BatchCreateTableWithInfo(gs.se, dbName, cloneTables, append(cs, ddl.OnExistIgnore)...) - if err != nil { + + if err := gs.SplitBatchCreateTable(dbName, cloneTables, cs...); err != nil { //It is possible to failure when TiDB does not support model.ActionCreateTables.
//In this circumstance, BatchCreateTableWithInfo returns errno.ErrInvalidDDLJob, //we fall back to old way that creating table one by one diff --git a/br/pkg/gluetidb/glue_test.go b/br/pkg/gluetidb/glue_test.go new file mode 100644 index 0000000000000..e7c2f64dcfaa5 --- /dev/null +++ b/br/pkg/gluetidb/glue_test.go @@ -0,0 +1,208 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gluetidb + +import ( + "context" + "strconv" + "testing" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/ddl" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/meta" + "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/testkit" + "github.com/stretchr/testify/require" +) + +// batch create table with table id reused +func TestSplitBatchCreateTableWithTableId(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists table_id_resued1") + tk.MustExec("drop table if exists table_id_resued2") + tk.MustExec("drop table if exists table_id_new") + + d := dom.DDL() + require.NotNil(t, d) + + infos1 := []*model.TableInfo{} + infos1 = append(infos1, &model.TableInfo{ + ID: 124, + Name: model.NewCIStr("table_id_resued1"), + }) + infos1 = append(infos1, &model.TableInfo{ + ID: 125, + Name: model.NewCIStr("table_id_resued2"), + }) + + se := &tidbSession{se: tk.Session()} + + // keep/reused table id verification + tk.Session().SetValue(sessionctx.QueryString, "skip") + err := se.SplitBatchCreateTable(model.NewCIStr("test"), infos1, ddl.AllocTableIDIf(func(ti *model.TableInfo) bool { + return false + })) + require.NoError(t, err) + + tk.MustQuery("select tidb_table_id from information_schema.tables where table_name = 'table_id_resued1'").Check(testkit.Rows("124")) + tk.MustQuery("select tidb_table_id from information_schema.tables where table_name = 'table_id_resued2'").Check(testkit.Rows("125")) + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnOthers) + + // allocate new table id verification + // query the global id + var id int64 + err = kv.RunInNewTxn(ctx, store, true, func(_ context.Context, txn kv.Transaction) error { + m := meta.NewMeta(txn) + var err error + id, err = m.GenGlobalID() + return err + }) + + require.NoError(t, err) + + infos2 := []*model.TableInfo{} + infos2 = append(infos2, &model.TableInfo{ + ID: 124, + Name: model.NewCIStr("table_id_new"), + }) + + tk.Session().SetValue(sessionctx.QueryString, "skip") + err = se.SplitBatchCreateTable(model.NewCIStr("test"), infos2, ddl.AllocTableIDIf(func(ti *model.TableInfo) bool { + return true + })) + require.NoError(t, err) + + idGen, ok := tk.MustQuery("select tidb_table_id from information_schema.tables where table_name = 'table_id_new'").Rows()[0][0].(string) + require.True(t, ok) + idGenNum, err := strconv.ParseInt(idGen, 10, 64) + require.NoError(t, err) + require.Greater(t, idGenNum, id) + + // a empty table 
info with len(info3) = 0 + infos3 := []*model.TableInfo{} + + err = se.SplitBatchCreateTable(model.NewCIStr("test"), infos3, ddl.AllocTableIDIf(func(ti *model.TableInfo) bool { + return false + })) + require.NoError(t, err) +} + +// batch create table with table id reused +func TestSplitBatchCreateTable(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists table_1") + tk.MustExec("drop table if exists table_2") + tk.MustExec("drop table if exists table_3") + + d := dom.DDL() + require.NotNil(t, d) + + infos := []*model.TableInfo{} + infos = append(infos, &model.TableInfo{ + ID: 1234, + Name: model.NewCIStr("tables_1"), + }) + infos = append(infos, &model.TableInfo{ + ID: 1235, + Name: model.NewCIStr("tables_2"), + }) + infos = append(infos, &model.TableInfo{ + ID: 1236, + Name: model.NewCIStr("tables_3"), + }) + + se := &tidbSession{se: tk.Session()} + + // keep/reused table id verification + tk.Session().SetValue(sessionctx.QueryString, "skip") + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/ddl/RestoreBatchCreateTableEntryTooLarge", "return(1)")) + err := se.SplitBatchCreateTable(model.NewCIStr("test"), infos, ddl.AllocTableIDIf(func(ti *model.TableInfo) bool { + return false + })) + + require.NoError(t, err) + tk.MustQuery("show tables like '%tables_%'").Check(testkit.Rows("tables_1", "tables_2", "tables_3")) + jobs := tk.MustQuery("admin show ddl jobs").Rows() + require.Greater(t, len(jobs), 3) + // check table_1 + job1 := jobs[0] + require.Equal(t, "test", job1[1]) + require.Equal(t, "tables_3", job1[2]) + require.Equal(t, "create tables", job1[3]) + require.Equal(t, "public", job1[4]) + + // check table_2 + job2 := jobs[1] + require.Equal(t, "test", job2[1]) + require.Equal(t, "tables_2", job2[2]) + require.Equal(t, "create tables", job2[3]) + require.Equal(t, "public", job2[4]) + + // check table_3 + job3 := jobs[2] + require.Equal(t, "test", job3[1]) + require.Equal(t, "tables_1", job3[2]) + require.Equal(t, "create tables", job3[3]) + require.Equal(t, "public", job3[4]) + + // check reused table id + tk.MustQuery("select tidb_table_id from information_schema.tables where table_name = 'tables_1'").Check(testkit.Rows("1234")) + tk.MustQuery("select tidb_table_id from information_schema.tables where table_name = 'tables_2'").Check(testkit.Rows("1235")) + tk.MustQuery("select tidb_table_id from information_schema.tables where table_name = 'tables_3'").Check(testkit.Rows("1236")) + + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/ddl/RestoreBatchCreateTableEntryTooLarge")) +} + +// batch create table with table id reused +func TestSplitBatchCreateTableFailWithEntryTooLarge(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists table_1") + tk.MustExec("drop table if exists table_2") + tk.MustExec("drop table if exists table_3") + + d := dom.DDL() + require.NotNil(t, d) + + infos := []*model.TableInfo{} + infos = append(infos, &model.TableInfo{ + Name: model.NewCIStr("tables_1"), + }) + infos = append(infos, &model.TableInfo{ + Name: model.NewCIStr("tables_2"), + }) + infos = append(infos, &model.TableInfo{ + Name: model.NewCIStr("tables_3"), + }) + + se := &tidbSession{se: tk.Session()} + + tk.Session().SetValue(sessionctx.QueryString, "skip") + require.NoError(t, 
failpoint.Enable("github.com/pingcap/tidb/ddl/RestoreBatchCreateTableEntryTooLarge", "return(0)")) + err := se.SplitBatchCreateTable(model.NewCIStr("test"), infos, ddl.AllocTableIDIf(func(ti *model.TableInfo) bool { + return true + })) + + require.True(t, kv.ErrEntryTooLarge.Equal(err)) + + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/ddl/RestoreBatchCreateTableEntryTooLarge")) +} diff --git a/br/pkg/lightning/backend/local/BUILD.bazel b/br/pkg/lightning/backend/local/BUILD.bazel index c034e6bdb2b3c..9524ab5febc2b 100644 --- a/br/pkg/lightning/backend/local/BUILD.bazel +++ b/br/pkg/lightning/backend/local/BUILD.bazel @@ -103,6 +103,7 @@ go_test( "//br/pkg/lightning/glue", "//br/pkg/lightning/log", "//br/pkg/lightning/mydump", + "//br/pkg/lightning/worker", "//br/pkg/membuf", "//br/pkg/mock", "//br/pkg/pdutil", diff --git a/br/pkg/lightning/backend/local/engine_test.go b/br/pkg/lightning/backend/local/engine_test.go index c7ffe04b95285..eae0225bb519a 100644 --- a/br/pkg/lightning/backend/local/engine_test.go +++ b/br/pkg/lightning/backend/local/engine_test.go @@ -31,8 +31,17 @@ import ( "github.com/stretchr/testify/require" ) -func TestIngestSSTWithClosedEngine(t *testing.T) { +func makePebbleDB(t *testing.T, opt *pebble.Options) (*pebble.DB, string) { dir := t.TempDir() + db, err := pebble.Open(path.Join(dir, "test"), opt) + require.NoError(t, err) + tmpPath := filepath.Join(dir, "test.sst") + err = os.Mkdir(tmpPath, 0o755) + require.NoError(t, err) + return db, tmpPath +} + +func TestIngestSSTWithClosedEngine(t *testing.T) { opt := &pebble.Options{ MemTableSize: 1024 * 1024, MaxConcurrentCompactions: 16, @@ -41,11 +50,7 @@ func TestIngestSSTWithClosedEngine(t *testing.T) { DisableWAL: true, ReadOnly: false, } - db, err := pebble.Open(filepath.Join(dir, "test"), opt) - require.NoError(t, err) - tmpPath := filepath.Join(dir, "test.sst") - err = os.Mkdir(tmpPath, 0o755) - require.NoError(t, err) + db, tmpPath := makePebbleDB(t, opt) _, engineUUID := backend.MakeUUID("ww", 0) engineCtx, cancel := context.WithCancel(context.Background()) diff --git a/br/pkg/lightning/backend/local/local.go b/br/pkg/lightning/backend/local/local.go index 6cbf763037262..eea3a734b0176 100644 --- a/br/pkg/lightning/backend/local/local.go +++ b/br/pkg/lightning/backend/local/local.go @@ -92,6 +92,7 @@ const ( gRPCKeepAliveTime = 10 * time.Minute gRPCKeepAliveTimeout = 5 * time.Minute gRPCBackOffMaxDelay = 10 * time.Minute + writeStallSleepTime = 10 * time.Second // The max ranges count in a batch to split and scatter. maxBatchSplitRanges = 4096 @@ -383,6 +384,12 @@ type local struct { encBuilder backend.EncodingBuilder targetInfoGetter backend.TargetInfoGetter + + // When TiKV is in normal mode, ingesting too many SSTs will cause TiKV write stall. + // To avoid this, we should check write stall before ingesting SSTs. Note that, we + // must check both leader node and followers in client side, because followers will + // not check write stall as long as ingest command is accepted by leader. 
+ shouldCheckWriteStall bool } func openDuplicateDB(storeDir string) (*pebble.DB, error) { @@ -506,6 +513,7 @@ func NewLocalBackend( logger: log.FromContext(ctx), encBuilder: NewEncodingBuilder(ctx), targetInfoGetter: NewTargetInfoGetter(tls, g, cfg.TiDB.PdAddr), + shouldCheckWriteStall: cfg.Cron.SwitchMode.Duration == 0, } if m, ok := metric.FromContext(ctx); ok { local.metrics = m @@ -1151,6 +1159,25 @@ func (local *local) Ingest(ctx context.Context, metas []*sst.SSTMeta, region *sp return resp, errors.Trace(err) } + if local.shouldCheckWriteStall { + for { + maybeWriteStall, err := local.checkWriteStall(ctx, region) + if err != nil { + return nil, err + } + if !maybeWriteStall { + break + } + log.FromContext(ctx).Warn("ingest maybe cause write stall, sleep and retry", + zap.Duration("duration", writeStallSleepTime)) + select { + case <-time.After(writeStallSleepTime): + case <-ctx.Done(): + return nil, errors.Trace(ctx.Err()) + } + } + } + req := &sst.MultiIngestRequest{ Context: reqCtx, Ssts: metas, @@ -1159,6 +1186,23 @@ func (local *local) Ingest(ctx context.Context, metas []*sst.SSTMeta, region *sp return resp, errors.Trace(err) } +func (local *local) checkWriteStall(ctx context.Context, region *split.RegionInfo) (bool, error) { + for _, peer := range region.Region.GetPeers() { + cli, err := local.getImportClient(ctx, peer.StoreId) + if err != nil { + return false, errors.Trace(err) + } + resp, err := cli.MultiIngest(ctx, &sst.MultiIngestRequest{}) + if err != nil { + return false, errors.Trace(err) + } + if resp.Error != nil && resp.Error.ServerIsBusy != nil { + return true, nil + } + } + return false, nil +} + func splitRangeBySizeProps(fullRange Range, sizeProps *sizeProperties, sizeLimit int64, keysLimit int64) []Range { ranges := make([]Range, 0, sizeProps.totalSize/uint64(sizeLimit)) curSize := uint64(0) diff --git a/br/pkg/lightning/backend/local/local_test.go b/br/pkg/lightning/backend/local/local_test.go index 9019ebaf6f62e..0913aa9b86801 100644 --- a/br/pkg/lightning/backend/local/local_test.go +++ b/br/pkg/lightning/backend/local/local_test.go @@ -18,10 +18,10 @@ import ( "bytes" "context" "encoding/binary" + "fmt" "io" "math" "math/rand" - "os" "path/filepath" "sort" "strings" @@ -42,6 +42,7 @@ import ( "github.com/pingcap/tidb/br/pkg/lightning/backend/kv" "github.com/pingcap/tidb/br/pkg/lightning/common" "github.com/pingcap/tidb/br/pkg/lightning/log" + "github.com/pingcap/tidb/br/pkg/lightning/worker" "github.com/pingcap/tidb/br/pkg/membuf" "github.com/pingcap/tidb/br/pkg/pdutil" "github.com/pingcap/tidb/br/pkg/restore/split" @@ -248,8 +249,6 @@ func TestRangeProperties(t *testing.T) { } func TestRangePropertiesWithPebble(t *testing.T) { - dir := t.TempDir() - sizeDistance := uint64(500) keysDistance := uint64(20) opt := &pebble.Options{ @@ -270,8 +269,7 @@ func TestRangePropertiesWithPebble(t *testing.T) { }, }, } - db, err := pebble.Open(filepath.Join(dir, "test"), opt) - require.NoError(t, err) + db, _ := makePebbleDB(t, opt) defer db.Close() // local collector @@ -288,7 +286,7 @@ func TestRangePropertiesWithPebble(t *testing.T) { key := make([]byte, 8) valueLen := rand.Intn(50) binary.BigEndian.PutUint64(key, uint64(i*100+j)) - err = wb.Set(key, value[:valueLen], writeOpt) + err := wb.Set(key, value[:valueLen], writeOpt) require.NoError(t, err) err = collector.Add(pebble.InternalKey{UserKey: key, Trailer: pebble.InternalKeyKindSet}, value[:valueLen]) require.NoError(t, err) @@ -315,7 +313,6 @@ func TestRangePropertiesWithPebble(t *testing.T) { } func 
testLocalWriter(t *testing.T, needSort bool, partitialSort bool) { - dir := t.TempDir() opt := &pebble.Options{ MemTableSize: 1024 * 1024, MaxConcurrentCompactions: 16, @@ -324,12 +321,8 @@ func testLocalWriter(t *testing.T, needSort bool, partitialSort bool) { DisableWAL: true, ReadOnly: false, } - db, err := pebble.Open(filepath.Join(dir, "test"), opt) - require.NoError(t, err) + db, tmpPath := makePebbleDB(t, opt) defer db.Close() - tmpPath := filepath.Join(dir, "test.sst") - err = os.Mkdir(tmpPath, 0o755) - require.NoError(t, err) _, engineUUID := backend.MakeUUID("ww", 0) engineCtx, cancel := context.WithCancel(context.Background()) @@ -575,7 +568,6 @@ func (i testIngester) ingest([]*sstMeta) error { } func TestLocalIngestLoop(t *testing.T) { - dir := t.TempDir() opt := &pebble.Options{ MemTableSize: 1024 * 1024, MaxConcurrentCompactions: 16, @@ -584,18 +576,14 @@ func TestLocalIngestLoop(t *testing.T) { DisableWAL: true, ReadOnly: false, } - db, err := pebble.Open(filepath.Join(dir, "test"), opt) - require.NoError(t, err) + db, tmpPath := makePebbleDB(t, opt) defer db.Close() - tmpPath := filepath.Join(dir, "test.sst") - err = os.Mkdir(tmpPath, 0o755) - require.NoError(t, err) _, engineUUID := backend.MakeUUID("ww", 0) engineCtx, cancel := context.WithCancel(context.Background()) f := Engine{ db: db, UUID: engineUUID, - sstDir: "", + sstDir: tmpPath, ctx: engineCtx, cancel: cancel, sstMetasChan: make(chan metaOrFlush, 64), @@ -648,7 +636,7 @@ func TestLocalIngestLoop(t *testing.T) { wg.Wait() f.mutex.RLock() - err = f.flushEngineWithoutLock(engineCtx) + err := f.flushEngineWithoutLock(engineCtx) require.NoError(t, err) f.mutex.RUnlock() @@ -743,7 +731,6 @@ func TestFilterOverlapRange(t *testing.T) { } func testMergeSSTs(t *testing.T, kvs [][]common.KvPair, meta *sstMeta) { - dir := t.TempDir() opt := &pebble.Options{ MemTableSize: 1024 * 1024, MaxConcurrentCompactions: 16, @@ -752,12 +739,8 @@ func testMergeSSTs(t *testing.T, kvs [][]common.KvPair, meta *sstMeta) { DisableWAL: true, ReadOnly: false, } - db, err := pebble.Open(filepath.Join(dir, "test"), opt) - require.NoError(t, err) + db, tmpPath := makePebbleDB(t, opt) defer db.Close() - tmpPath := filepath.Join(dir, "test.sst") - err = os.Mkdir(tmpPath, 0o755) - require.NoError(t, err) _, engineUUID := backend.MakeUUID("ww", 0) engineCtx, cancel := context.WithCancel(context.Background()) @@ -848,49 +831,90 @@ func TestMergeSSTsDuplicated(t *testing.T) { type mockPdClient struct { pd.Client - stores []*metapb.Store + stores []*metapb.Store + regions []*pd.Region } func (c *mockPdClient) GetAllStores(ctx context.Context, opts ...pd.GetStoreOption) ([]*metapb.Store, error) { return c.stores, nil } +func (c *mockPdClient) ScanRegions(ctx context.Context, key, endKey []byte, limit int) ([]*pd.Region, error) { + return c.regions, nil +} + type mockGrpcErr struct{} func (e mockGrpcErr) GRPCStatus() *status.Status { - return status.New(codes.Unimplemented, "unimplmented") + return status.New(codes.Unimplemented, "unimplemented") } func (e mockGrpcErr) Error() string { - return "unimplmented" + return "unimplemented" } type mockImportClient struct { sst.ImportSSTClient store *metapb.Store + resp *sst.IngestResponse err error retry int cnt int multiIngestCheckFn func(s *metapb.Store) bool + apiInvokeRecorder map[string][]uint64 +} + +func newMockImportClient() *mockImportClient { + return &mockImportClient{ + multiIngestCheckFn: func(s *metapb.Store) bool { + return true + }, + } } func (c *mockImportClient) MultiIngest(context.Context, 
*sst.MultiIngestRequest, ...grpc.CallOption) (*sst.IngestResponse, error) { defer func() { c.cnt++ }() - if c.cnt < c.retry && c.err != nil { - return nil, c.err + if c.apiInvokeRecorder != nil { + c.apiInvokeRecorder["MultiIngest"] = append(c.apiInvokeRecorder["MultiIngest"], c.store.GetId()) + } + if c.cnt < c.retry && (c.err != nil || c.resp != nil) { + return c.resp, c.err } if !c.multiIngestCheckFn(c.store) { return nil, mockGrpcErr{} } - return nil, nil + return &sst.IngestResponse{}, nil +} + +type mockWriteClient struct { + sst.ImportSST_WriteClient + writeResp *sst.WriteResponse +} + +func (m mockWriteClient) Send(request *sst.WriteRequest) error { + return nil +} + +func (m mockWriteClient) CloseAndRecv() (*sst.WriteResponse, error) { + return m.writeResp, nil +} + +func (c *mockImportClient) Write(ctx context.Context, opts ...grpc.CallOption) (sst.ImportSST_WriteClient, error) { + if c.apiInvokeRecorder != nil { + c.apiInvokeRecorder["Write"] = append(c.apiInvokeRecorder["Write"], c.store.GetId()) + } + return mockWriteClient{writeResp: &sst.WriteResponse{Metas: []*sst.SSTMeta{ + {}, {}, {}, + }}}, nil } type mockImportClientFactory struct { - stores []*metapb.Store - createClientFn func(store *metapb.Store) sst.ImportSSTClient + stores []*metapb.Store + createClientFn func(store *metapb.Store) sst.ImportSSTClient + apiInvokeRecorder map[string][]uint64 } func (f *mockImportClientFactory) Create(_ context.Context, storeID uint64) (sst.ImportSSTClient, error) { @@ -899,7 +923,7 @@ func (f *mockImportClientFactory) Create(_ context.Context, storeID uint64) (sst return f.createClientFn(store), nil } } - return nil, errors.New("store not found") + return nil, fmt.Errorf("store %d not found", storeID) } func (f *mockImportClientFactory) Close() {} @@ -1231,3 +1255,75 @@ func TestLocalIsRetryableTiKVWriteError(t *testing.T) { require.True(t, l.isRetryableImportTiKVError(io.EOF)) require.True(t, l.isRetryableImportTiKVError(errors.Trace(io.EOF))) } + +func TestCheckPeersBusy(t *testing.T) { + ctx := context.Background() + pdCli := &mockPdClient{} + pdCtl := &pdutil.PdController{} + pdCtl.SetPDClient(pdCli) + + keys := [][]byte{[]byte(""), []byte("a"), []byte("b"), []byte("")} + splitCli := initTestSplitClient3Replica(keys, nil) + apiInvokeRecorder := map[string][]uint64{} + serverIsBusyResp := &sst.IngestResponse{ + Error: &errorpb.Error{ + ServerIsBusy: &errorpb.ServerIsBusy{}, + }} + + createTimeStore12 := 0 + local := &local{ + pdCtl: pdCtl, + splitCli: splitCli, + importClientFactory: &mockImportClientFactory{ + stores: []*metapb.Store{ + // region ["", "a") is not used, skip (1, 2, 3) + {Id: 11}, {Id: 12}, {Id: 13}, // region ["a", "b") + {Id: 21}, {Id: 22}, {Id: 23}, // region ["b", "") + }, + createClientFn: func(store *metapb.Store) sst.ImportSSTClient { + importCli := newMockImportClient() + importCli.store = store + importCli.apiInvokeRecorder = apiInvokeRecorder + if store.Id == 12 { + createTimeStore12++ + // the second time to checkWriteStall + if createTimeStore12 == 2 { + importCli.retry = 1 + importCli.resp = serverIsBusyResp + } + } + return importCli + }, + }, + logger: log.L(), + ingestConcurrency: worker.NewPool(ctx, 1, "ingest"), + writeLimiter: noopStoreWriteLimiter{}, + bufferPool: membuf.NewPool(), + supportMultiIngest: true, + shouldCheckWriteStall: true, + } + + db, tmpPath := makePebbleDB(t, nil) + _, engineUUID := backend.MakeUUID("ww", 0) + engineCtx, cancel := context.WithCancel(context.Background()) + f := &Engine{ + db: db, + UUID: engineUUID, + sstDir: 
tmpPath, + ctx: engineCtx, + cancel: cancel, + sstMetasChan: make(chan metaOrFlush, 64), + keyAdapter: noopKeyAdapter{}, + logger: log.L(), + } + err := f.db.Set([]byte("a"), []byte("a"), nil) + require.NoError(t, err) + err = f.db.Set([]byte("b"), []byte("b"), nil) + require.NoError(t, err) + err = local.writeAndIngestByRange(ctx, f, []byte("a"), []byte("c"), 0, 0) + require.NoError(t, err) + + require.Equal(t, []uint64{11, 12, 13, 21, 22, 23}, apiInvokeRecorder["Write"]) + // store 12 has a follower busy, so it will cause region peers (11, 12, 13) retry once + require.Equal(t, []uint64{11, 12, 11, 12, 13, 11, 21, 22, 23, 21}, apiInvokeRecorder["MultiIngest"]) +} diff --git a/br/pkg/lightning/backend/local/localhelper_test.go b/br/pkg/lightning/backend/local/localhelper_test.go index 6cbf7f2f14808..023fade304fae 100644 --- a/br/pkg/lightning/backend/local/localhelper_test.go +++ b/br/pkg/lightning/backend/local/localhelper_test.go @@ -47,7 +47,7 @@ func init() { splitRetryTimes = 2 } -type testClient struct { +type testSplitClient struct { mu sync.RWMutex stores map[uint64]*metapb.Store regions map[uint64]*split.RegionInfo @@ -57,17 +57,17 @@ type testClient struct { hook clientHook } -func newTestClient( +func newTestSplitClient( stores map[uint64]*metapb.Store, regions map[uint64]*split.RegionInfo, nextRegionID uint64, hook clientHook, -) *testClient { +) *testSplitClient { regionsInfo := &pdtypes.RegionTree{} for _, regionInfo := range regions { regionsInfo.SetRegion(pdtypes.NewRegionInfo(regionInfo.Region, regionInfo.Leader)) } - return &testClient{ + return &testSplitClient{ stores: stores, regions: regions, regionsInfo: regionsInfo, @@ -77,17 +77,17 @@ func newTestClient( } // ScatterRegions scatters regions in a batch. -func (c *testClient) ScatterRegions(ctx context.Context, regionInfo []*split.RegionInfo) error { +func (c *testSplitClient) ScatterRegions(ctx context.Context, regionInfo []*split.RegionInfo) error { return nil } -func (c *testClient) GetAllRegions() map[uint64]*split.RegionInfo { +func (c *testSplitClient) GetAllRegions() map[uint64]*split.RegionInfo { c.mu.RLock() defer c.mu.RUnlock() return c.regions } -func (c *testClient) GetStore(ctx context.Context, storeID uint64) (*metapb.Store, error) { +func (c *testSplitClient) GetStore(ctx context.Context, storeID uint64) (*metapb.Store, error) { c.mu.RLock() defer c.mu.RUnlock() store, ok := c.stores[storeID] @@ -97,19 +97,18 @@ func (c *testClient) GetStore(ctx context.Context, storeID uint64) (*metapb.Stor return store, nil } -func (c *testClient) GetRegion(ctx context.Context, key []byte) (*split.RegionInfo, error) { +func (c *testSplitClient) GetRegion(ctx context.Context, key []byte) (*split.RegionInfo, error) { c.mu.RLock() defer c.mu.RUnlock() for _, region := range c.regions { - if bytes.Compare(key, region.Region.StartKey) >= 0 && - (len(region.Region.EndKey) == 0 || bytes.Compare(key, region.Region.EndKey) < 0) { + if bytes.Compare(key, region.Region.StartKey) >= 0 && beforeEnd(key, region.Region.EndKey) { return region, nil } } return nil, errors.Errorf("region not found: key=%s", string(key)) } -func (c *testClient) GetRegionByID(ctx context.Context, regionID uint64) (*split.RegionInfo, error) { +func (c *testSplitClient) GetRegionByID(ctx context.Context, regionID uint64) (*split.RegionInfo, error) { c.mu.RLock() defer c.mu.RUnlock() region, ok := c.regions[regionID] @@ -119,7 +118,7 @@ func (c *testClient) GetRegionByID(ctx context.Context, regionID uint64) (*split return region, nil } -func (c 
*testClient) SplitRegion( +func (c *testSplitClient) SplitRegion( ctx context.Context, regionInfo *split.RegionInfo, key []byte, @@ -130,7 +129,7 @@ func (c *testClient) SplitRegion( splitKey := codec.EncodeBytes([]byte{}, key) for _, region := range c.regions { if bytes.Compare(splitKey, region.Region.StartKey) >= 0 && - (len(region.Region.EndKey) == 0 || bytes.Compare(splitKey, region.Region.EndKey) < 0) { + beforeEnd(splitKey, region.Region.EndKey) { target = region } } @@ -159,7 +158,7 @@ func (c *testClient) SplitRegion( return newRegion, nil } -func (c *testClient) BatchSplitRegionsWithOrigin( +func (c *testSplitClient) BatchSplitRegionsWithOrigin( ctx context.Context, regionInfo *split.RegionInfo, keys [][]byte, ) (*split.RegionInfo, []*split.RegionInfo, error) { c.mu.Lock() @@ -234,24 +233,24 @@ func (c *testClient) BatchSplitRegionsWithOrigin( return target, newRegions, err } -func (c *testClient) BatchSplitRegions( +func (c *testSplitClient) BatchSplitRegions( ctx context.Context, regionInfo *split.RegionInfo, keys [][]byte, ) ([]*split.RegionInfo, error) { _, newRegions, err := c.BatchSplitRegionsWithOrigin(ctx, regionInfo, keys) return newRegions, err } -func (c *testClient) ScatterRegion(ctx context.Context, regionInfo *split.RegionInfo) error { +func (c *testSplitClient) ScatterRegion(ctx context.Context, regionInfo *split.RegionInfo) error { return nil } -func (c *testClient) GetOperator(ctx context.Context, regionID uint64) (*pdpb.GetOperatorResponse, error) { +func (c *testSplitClient) GetOperator(ctx context.Context, regionID uint64) (*pdpb.GetOperatorResponse, error) { return &pdpb.GetOperatorResponse{ Header: new(pdpb.ResponseHeader), }, nil } -func (c *testClient) ScanRegions(ctx context.Context, key, endKey []byte, limit int) ([]*split.RegionInfo, error) { +func (c *testSplitClient) ScanRegions(ctx context.Context, key, endKey []byte, limit int) ([]*split.RegionInfo, error) { if c.hook != nil { key, endKey, limit = c.hook.BeforeScanRegions(ctx, key, endKey, limit) } @@ -272,19 +271,19 @@ func (c *testClient) ScanRegions(ctx context.Context, key, endKey []byte, limit return regions, err } -func (c *testClient) GetPlacementRule(ctx context.Context, groupID, ruleID string) (r pdtypes.Rule, err error) { +func (c *testSplitClient) GetPlacementRule(ctx context.Context, groupID, ruleID string) (r pdtypes.Rule, err error) { return } -func (c *testClient) SetPlacementRule(ctx context.Context, rule pdtypes.Rule) error { +func (c *testSplitClient) SetPlacementRule(ctx context.Context, rule pdtypes.Rule) error { return nil } -func (c *testClient) DeletePlacementRule(ctx context.Context, groupID, ruleID string) error { +func (c *testSplitClient) DeletePlacementRule(ctx context.Context, groupID, ruleID string) error { return nil } -func (c *testClient) SetStoresLabel(ctx context.Context, stores []uint64, labelKey, labelValue string) error { +func (c *testSplitClient) SetStoresLabel(ctx context.Context, stores []uint64, labelKey, labelValue string) error { return nil } @@ -305,7 +304,7 @@ func cloneRegion(region *split.RegionInfo) *split.RegionInfo { // For keys ["", "aay", "bba", "bbh", "cca", ""], the key ranges of // regions are [, aay), [aay, bba), [bba, bbh), [bbh, cca), [cca, ). 
-func initTestClient(keys [][]byte, hook clientHook) *testClient { +func initTestSplitClient(keys [][]byte, hook clientHook) *testSplitClient { peers := make([]*metapb.Peer, 1) peers[0] = &metapb.Peer{ Id: 1, @@ -329,13 +328,56 @@ func initTestClient(keys [][]byte, hook clientHook) *testClient { EndKey: endKey, RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 1}, }, + Leader: peers[0], } } stores := make(map[uint64]*metapb.Store) stores[1] = &metapb.Store{ Id: 1, } - return newTestClient(stores, regions, uint64(len(keys)), hook) + return newTestSplitClient(stores, regions, uint64(len(keys)), hook) +} + +// initTestSplitClient3Replica will create a client that each region has 3 replicas, and their IDs and StoreIDs are +// (1, 2, 3), (11, 12, 13), ... +// For keys ["", "aay", "bba", "bbh", "cca", ""], the key ranges of +// region ranges are [, aay), [aay, bba), [bba, bbh), [bbh, cca), [cca, ). +func initTestSplitClient3Replica(keys [][]byte, hook clientHook) *testSplitClient { + regions := make(map[uint64]*split.RegionInfo) + stores := make(map[uint64]*metapb.Store) + for i := uint64(1); i < uint64(len(keys)); i++ { + startKey := keys[i-1] + if len(startKey) != 0 { + startKey = codec.EncodeBytes([]byte{}, startKey) + } + endKey := keys[i] + if len(endKey) != 0 { + endKey = codec.EncodeBytes([]byte{}, endKey) + } + baseID := (i-1)*10 + 1 + peers := make([]*metapb.Peer, 3) + for j := 0; j < 3; j++ { + peers[j] = &metapb.Peer{ + Id: baseID + uint64(j), + StoreId: baseID + uint64(j), + } + } + + regions[baseID] = &split.RegionInfo{ + Region: &metapb.Region{ + Id: baseID, + Peers: peers, + StartKey: startKey, + EndKey: endKey, + RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 1}, + }, + Leader: peers[0], + } + stores[baseID] = &metapb.Store{ + Id: baseID, + } + } + return newTestSplitClient(stores, regions, uint64(len(keys)), hook) } func checkRegionRanges(t *testing.T, regions []*split.RegionInfo, keys [][]byte) { @@ -376,7 +418,7 @@ func (h *noopHook) AfterScanRegions(res []*split.RegionInfo, err error) ([]*spli type batchSplitHook interface { setup(t *testing.T) func() - check(t *testing.T, cli *testClient) + check(t *testing.T, cli *testSplitClient) } type defaultHook struct{} @@ -392,7 +434,7 @@ func (d defaultHook) setup(t *testing.T) func() { } } -func (d defaultHook) check(t *testing.T, cli *testClient) { +func (d defaultHook) check(t *testing.T, cli *testSplitClient) { // so with a batch split size of 4, there will be 7 time batch split // 1. region: [aay, bba), keys: [b, ba, bb] // 2. region: [bbh, cca), keys: [bc, bd, be, bf] @@ -414,7 +456,7 @@ func doTestBatchSplitRegionByRanges(ctx context.Context, t *testing.T, hook clie defer deferFunc() keys := [][]byte{[]byte(""), []byte("aay"), []byte("bba"), []byte("bbh"), []byte("cca"), []byte("")} - client := initTestClient(keys, hook) + client := initTestSplitClient(keys, hook) local := &local{ splitCli: client, g: glue.NewExternalTiDBGlue(nil, mysql.ModeNone), @@ -479,7 +521,7 @@ func (h batchSizeHook) setup(t *testing.T) func() { } } -func (h batchSizeHook) check(t *testing.T, cli *testClient) { +func (h batchSizeHook) check(t *testing.T, cli *testSplitClient) { // so with a batch split key size of 6, there will be 9 time batch split // 1. region: [aay, bba), keys: [b, ba, bb] // 2. 
region: [bbh, cca), keys: [bc, bd, be] @@ -583,7 +625,7 @@ func TestSplitAndScatterRegionInBatches(t *testing.T) { defer deferFunc() keys := [][]byte{[]byte(""), []byte("a"), []byte("b"), []byte("")} - client := initTestClient(keys, nil) + client := initTestSplitClient(keys, nil) local := &local{ splitCli: client, g: glue.NewExternalTiDBGlue(nil, mysql.ModeNone), @@ -670,7 +712,7 @@ func doTestBatchSplitByRangesWithClusteredIndex(t *testing.T, hook clientHook) { keys = append(keys, key) } keys = append(keys, tableEndKey, []byte("")) - client := initTestClient(keys, hook) + client := initTestSplitClient(keys, hook) local := &local{ splitCli: client, g: glue.NewExternalTiDBGlue(nil, mysql.ModeNone), diff --git a/br/pkg/restore/import_retry.go b/br/pkg/restore/import_retry.go index 6f3b9fc1cca53..4a6bdf1b8afcf 100644 --- a/br/pkg/restore/import_retry.go +++ b/br/pkg/restore/import_retry.go @@ -4,9 +4,11 @@ package restore import ( "context" + "strings" "time" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/errorpb" "github.com/pingcap/kvproto/pkg/import_sstpb" "github.com/pingcap/kvproto/pkg/metapb" @@ -85,6 +87,21 @@ func (o *OverRegionsInRangeController) tryFindLeader(ctx context.Context, region // handleInRegionError handles the error happens internal in the region. Update the region info, and perform a suitable backoff. func (o *OverRegionsInRangeController) handleInRegionError(ctx context.Context, result RPCResult, region *split.RegionInfo) (cont bool) { + if result.StoreError.GetServerIsBusy() != nil { + if strings.Contains(result.StoreError.GetMessage(), "memory is limited") { + sleepDuration := 15 * time.Second + + failpoint.Inject("hint-memory-is-limited", func(val failpoint.Value) { + if val.(bool) { + logutil.CL(ctx).Debug("failpoint hint-memory-is-limited injected.") + sleepDuration = 100 * time.Microsecond + } + }) + time.Sleep(sleepDuration) + return true + } + } + if nl := result.StoreError.GetNotLeader(); nl != nil { if nl.Leader != nil { region.Leader = nl.Leader diff --git a/br/pkg/restore/import_retry_test.go b/br/pkg/restore/import_retry_test.go index 6f3d8f490ef13..4e885657f998f 100644 --- a/br/pkg/restore/import_retry_test.go +++ b/br/pkg/restore/import_retry_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" backuppb "github.com/pingcap/kvproto/pkg/brpb" "github.com/pingcap/kvproto/pkg/errorpb" "github.com/pingcap/kvproto/pkg/import_sstpb" @@ -163,6 +164,48 @@ func TestServerIsBusy(t *testing.T) { require.NoError(t, err) assertRegions(t, idEqualsTo2Regions, "aay", "bba") assertRegions(t, meetRegions, "", "aay", "bba", "bbh", "cca", "") + require.Equal(t, rs.RetryTimes(), 1) +} + +func TestServerIsBusyWithMemoryIsLimited(t *testing.T) { + _ = failpoint.Enable("github.com/pingcap/tidb/br/pkg/restore/hint-memory-is-limited", "return(true)") + defer func() { + _ = failpoint.Disable("github.com/pingcap/tidb/br/pkg/restore/hint-memory-is-limited") + }() + + // region: [, aay), [aay, bba), [bba, bbh), [bbh, cca), [cca, ) + cli := initTestClient(false) + rs := utils.InitialRetryState(2, 0, 0) + ctl := restore.OverRegionsInRange([]byte(""), []byte(""), cli, &rs) + ctx := context.Background() + + serverIsBusy := errorpb.Error{ + Message: "memory is limited", + ServerIsBusy: &errorpb.ServerIsBusy{ + Reason: "", + }, + } + // record the regions we didn't touch. + meetRegions := []*split.RegionInfo{} + // record all regions we meet with id == 2. 
+ idEqualsTo2Regions := []*split.RegionInfo{} + theFirstRun := true + err := ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) restore.RPCResult { + if theFirstRun && r.Region.Id == 2 { + idEqualsTo2Regions = append(idEqualsTo2Regions, r) + theFirstRun = false + return restore.RPCResult{ + StoreError: &serverIsBusy, + } + } + meetRegions = append(meetRegions, r) + return restore.RPCResultOK() + }) + + require.NoError(t, err) + assertRegions(t, idEqualsTo2Regions, "aay", "bba") + assertRegions(t, meetRegions, "", "aay", "bba", "bbh", "cca", "") + require.Equal(t, rs.RetryTimes(), 0) } func printRegion(name string, infos []*split.RegionInfo) { diff --git a/br/pkg/streamhelper/advancer_daemon.go b/br/pkg/streamhelper/advancer_daemon.go index 10f43e105ccbe..4e3b68eb3fbf5 100644 --- a/br/pkg/streamhelper/advancer_daemon.go +++ b/br/pkg/streamhelper/advancer_daemon.go @@ -26,10 +26,14 @@ func (c *CheckpointAdvancer) OnTick(ctx context.Context) (err error) { return c.tick(ctx) } -// OnStart implements daemon.Interface. +// OnStart implements daemon.Interface. It will be called when the log backup service starts. func (c *CheckpointAdvancer) OnStart(ctx context.Context) { - metrics.AdvancerOwner.Set(1.0) c.StartTaskListener(ctx) +} + +// OnBecomeOwner implements daemon.Interface. It will be called when the tidb-server becomes the owner. +func (c *CheckpointAdvancer) OnBecomeOwner(ctx context.Context) { + metrics.AdvancerOwner.Set(1.0) c.spawnSubscriptionHandler(ctx) go func() { <-ctx.Done() diff --git a/br/pkg/streamhelper/daemon/interface.go b/br/pkg/streamhelper/daemon/interface.go index 544d67b153a0a..9f651d9488fb6 100644 --- a/br/pkg/streamhelper/daemon/interface.go +++ b/br/pkg/streamhelper/daemon/interface.go @@ -6,9 +6,11 @@ import "context" // Interface describes the lifetime hook of a daemon application. type Interface interface { - // OnStart would be called once become the owner. - // The context passed in would be canceled once it is no more the owner. + // OnStart starts the service whether the tidb-server is the owner or not. OnStart(ctx context.Context) + // OnBecomeOwner would be called once this node becomes the owner. + // The context passed in would be canceled once it is no longer the owner. + OnBecomeOwner(ctx context.Context) // OnTick would be called periodically. // The error can be recorded. OnTick(ctx context.Context) error diff --git a/br/pkg/streamhelper/daemon/owner_daemon.go b/br/pkg/streamhelper/daemon/owner_daemon.go index 3f14315957c43..533e38b7296c1 100644 --- a/br/pkg/streamhelper/daemon/owner_daemon.go +++ b/br/pkg/streamhelper/daemon/owner_daemon.go @@ -55,7 +55,7 @@ func (od *OwnerDaemon) ownerTick(ctx context.Context) { od.cancel = cancel log.Info("daemon became owner", zap.String("id", od.manager.ID()), zap.String("daemon-id", od.daemon.Name())) // Note: maybe save the context so we can cancel the tick when we are not owner? - od.daemon.OnStart(cx) + od.daemon.OnBecomeOwner(cx) } // Tick anyway. @@ -72,6 +72,10 @@ func (od *OwnerDaemon) Begin(ctx context.Context) (func(), error) { return nil, err } + // start the service. + od.daemon.OnStart(ctx) + + // tick starts.
tick := time.NewTicker(od.tickInterval) loop := func() { log.Info("begin running daemon", zap.String("id", od.manager.ID()), zap.String("daemon-id", od.daemon.Name())) diff --git a/br/pkg/streamhelper/daemon/owner_daemon_test.go b/br/pkg/streamhelper/daemon/owner_daemon_test.go index 74251d0b410a1..7ae6ebf38e59e 100644 --- a/br/pkg/streamhelper/daemon/owner_daemon_test.go +++ b/br/pkg/streamhelper/daemon/owner_daemon_test.go @@ -16,7 +16,8 @@ import ( type anApp struct { sync.Mutex - begun bool + serviceStart bool + begun bool tickingMessenger chan struct{} tickingMessengerOnce *sync.Once @@ -33,9 +34,14 @@ func newTestApp(t *testing.T) *anApp { } } -// OnStart would be called once become the owner. -// The context passed in would be canceled once it is no more the owner. +// OnStart implements daemon.Interface. func (a *anApp) OnStart(ctx context.Context) { + a.serviceStart = true +} + +// OnBecomeOwner would be called once this node becomes the owner. +// The context passed in would be canceled once it is no longer the owner. +func (a *anApp) OnBecomeOwner(ctx context.Context) { a.Lock() defer a.Unlock() if a.begun { @@ -87,6 +93,10 @@ func (a *anApp) Running() bool { return a.begun } +func (a *anApp) AssertService(req *require.Assertions, serviceStart bool) { + req.True(a.serviceStart == serviceStart) +} + func (a *anApp) AssertTick(timeout time.Duration) { a.Lock() messenger := a.tickingMessenger @@ -129,8 +139,10 @@ func TestDaemon(t *testing.T) { ow := owner.NewMockManager(ctx, "owner_daemon_test") d := daemon.New(app, ow, 100*time.Millisecond) + app.AssertService(req, false) f, err := d.Begin(ctx) req.NoError(err) + app.AssertService(req, true) go f() app.AssertStart(1 * time.Second) app.AssertTick(1 * time.Second) diff --git a/br/pkg/task/restore.go b/br/pkg/task/restore.go index 601897883e727..efc8b85298ec0 100644 --- a/br/pkg/task/restore.go +++ b/br/pkg/task/restore.go @@ -512,10 +512,7 @@ func runRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf // according to https://github.com/pingcap/tidb/issues/34167. // we should get the real config from tikv to adapt the dynamic region. httpCli := httputil.NewClient(mgr.GetTLSConfig()) - mergeRegionSize, mergeRegionCount, err = mgr.GetMergeRegionSizeAndCount(ctx, httpCli) - if err != nil { - return errors.Trace(err) - } + mergeRegionSize, mergeRegionCount = mgr.GetMergeRegionSizeAndCount(ctx, httpCli) } keepaliveCfg.PermitWithoutStream = true diff --git a/br/pkg/task/restore_raw.go b/br/pkg/task/restore_raw.go index 6c15cd9989512..7b80ac18b4d87 100644 --- a/br/pkg/task/restore_raw.go +++ b/br/pkg/task/restore_raw.go @@ -80,10 +80,7 @@ func RunRestoreRaw(c context.Context, g glue.Glue, cmdName string, cfg *RestoreR // according to https://github.com/pingcap/tidb/issues/34167. // we should get the real config from tikv to adapt the dynamic region.
httpCli := httputil.NewClient(mgr.GetTLSConfig()) - mergeRegionSize, mergeRegionCount, err = mgr.GetMergeRegionSizeAndCount(ctx, httpCli) - if err != nil { - return errors.Trace(err) - } + mergeRegionSize, mergeRegionCount = mgr.GetMergeRegionSizeAndCount(ctx, httpCli) } keepaliveCfg := GetKeepalive(&cfg.Config) diff --git a/br/pkg/utils/BUILD.bazel b/br/pkg/utils/BUILD.bazel index c3bcc629183d5..ebf579ba0fb9d 100644 --- a/br/pkg/utils/BUILD.bazel +++ b/br/pkg/utils/BUILD.bazel @@ -54,6 +54,7 @@ go_library( "@org_golang_google_grpc//status", "@org_golang_x_net//http/httpproxy", "@org_golang_x_sync//errgroup", + "@org_uber_go_atomic//:atomic", "@org_uber_go_multierr//:multierr", "@org_uber_go_zap//:zap", "@org_uber_go_zap//zapcore", diff --git a/br/pkg/utils/backoff.go b/br/pkg/utils/backoff.go index bff2490b56650..08df56c1c1e53 100644 --- a/br/pkg/utils/backoff.go +++ b/br/pkg/utils/backoff.go @@ -82,6 +82,12 @@ func (rs *RetryState) RecordRetry() { rs.retryTimes++ } +// RetryTimes returns the retry times. +// usage: unit test. +func (rs *RetryState) RetryTimes() int { + return rs.retryTimes +} + // Attempt implements the `Backoffer`. // TODO: Maybe use this to replace the `exponentialBackoffer` (which is nearly homomorphic to this)? func (rs *RetryState) Attempt() int { diff --git a/br/pkg/utils/db.go b/br/pkg/utils/db.go index 9574c06670573..701379f5aa67b 100644 --- a/br/pkg/utils/db.go +++ b/br/pkg/utils/db.go @@ -6,13 +6,13 @@ import ( "context" "database/sql" "strings" - "sync" "github.com/pingcap/errors" "github.com/pingcap/log" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/util/sqlexec" + "go.uber.org/atomic" "go.uber.org/zap" ) @@ -25,8 +25,7 @@ var ( _ DBExecutor = &sql.DB{} _ DBExecutor = &sql.Conn{} - LogBackupTaskMutex sync.Mutex - logBackupTaskCount int + logBackupTaskCount = atomic.NewInt32(0) ) // QueryExecutor is a interface for exec query @@ -134,26 +133,24 @@ func SetGcRatio(ctx sqlexec.RestrictedSQLExecutor, ratio string) error { // LogBackupTaskCountInc increases the count of log backup task. func LogBackupTaskCountInc() { - LogBackupTaskMutex.Lock() - logBackupTaskCount++ - LogBackupTaskMutex.Unlock() + logBackupTaskCount.Inc() + log.Info("inc log backup task", zap.Int32("count", logBackupTaskCount.Load())) } // LogBackupTaskCountDec decreases the count of log backup task. func LogBackupTaskCountDec() { - LogBackupTaskMutex.Lock() - logBackupTaskCount-- - LogBackupTaskMutex.Unlock() + logBackupTaskCount.Dec() + log.Info("dec log backup task", zap.Int32("count", logBackupTaskCount.Load())) } // CheckLogBackupTaskExist checks that whether log-backup is existed. func CheckLogBackupTaskExist() bool { - return logBackupTaskCount > 0 + return logBackupTaskCount.Load() > 0 } // IsLogBackupInUse checks the log backup task existed. func IsLogBackupInUse(ctx sessionctx.Context) bool { - return CheckLogBackupEnabled(ctx) && CheckLogBackupTaskExist() + return CheckLogBackupTaskExist() } // GetTidbNewCollationEnabled returns the variable name of NewCollationEnabled. diff --git a/br/tests/br_restore_log_task_enable/run.sh b/br/tests/br_restore_log_task_enable/run.sh index 923f8fe7c2b33..5525123f74b54 100644 --- a/br/tests/br_restore_log_task_enable/run.sh +++ b/br/tests/br_restore_log_task_enable/run.sh @@ -20,6 +20,10 @@ TABLE="usertable" # start log task run_br log start --task-name 1234 -s "local://$TEST_DIR/$DB/log" --pd $PD_ADDR +if ! 
grep -i "inc log backup task" "$TEST_DIR/tidb.log"; then + echo "TEST: [$TEST_NAME] log start failed!" + exit 1 +fi run_sql "CREATE DATABASE $DB;" run_sql "CREATE TABLE $DB.$TABLE (id int);" @@ -47,7 +51,12 @@ run_br restore full -s "local://$TEST_DIR/$DB/full" --pd $PD_ADDR && exit 1 run_br restore point -s "local://$TEST_DIR/$DB/log" --pd $PD_ADDR && exit 1 # stop log task -run_br log stop --task-name 1234 --pd $PD_ADDR +unset BR_LOG_TO_TERM +run_br log stop --task-name 1234 --pd $PD_ADDR +if ! grep -i "dec log backup task" "$TEST_DIR/tidb.log"; then + echo "TEST: [$TEST_NAME] log stop failed!" + exit 1 +fi # restore full (should be success) run_br restore full -s "local://$TEST_DIR/$DB/full" --pd $PD_ADDR diff --git a/br/tests/lightning_reload_cert/run.sh b/br/tests/lightning_reload_cert/run.sh index be0c5ff40421e..e06ef8d7fbf51 100644 --- a/br/tests/lightning_reload_cert/run.sh +++ b/br/tests/lightning_reload_cert/run.sh @@ -29,7 +29,7 @@ shpid="$!" sleep 15 ok=0 for _ in {0..60}; do - if grep -Fq "connection error" "$TEST_DIR"/lightning.log; then + if grep -Fq "connection closed before server preface received" "$TEST_DIR"/lightning.log; then ok=1 break fi diff --git a/cmd/explaintest/r/index_merge.result b/cmd/explaintest/r/index_merge.result index 0233dbdb55f52..3ab63b5da8306 100644 --- a/cmd/explaintest/r/index_merge.result +++ b/cmd/explaintest/r/index_merge.result @@ -391,14 +391,14 @@ Delete_11 N/A root N/A └─SelectLock_17 4056.68 root for update 0 └─HashJoin_33 4056.68 root inner join, equal:[eq(test.t1.c1, test.t1.c1)] ├─HashAgg_36(Build) 3245.34 root group by:test.t1.c1, funcs:firstrow(test.t1.c1)->test.t1.c1 - │ └─IndexMerge_41 2248.30 root type: union - │ ├─IndexRangeScan_37(Build) 3323.33 cop[tikv] table:t1, index:c1(c1) range:[-inf,10), keep order:false, stats:pseudo - │ ├─IndexRangeScan_38(Build) 3323.33 cop[tikv] table:t1, index:c2(c2) range:[-inf,10), keep order:false, stats:pseudo - │ └─Selection_40(Probe) 2248.30 cop[tikv] not(isnull(test.t1.c1)), or(lt(test.t1.c1, 10), and(lt(test.t1.c2, 10), lt(test.t1.c3, 10))) - │ └─TableRowIDScan_39 5542.21 cop[tikv] table:t1 keep order:false, stats:pseudo - └─TableReader_44(Probe) 9990.00 root data:Selection_43 - └─Selection_43 9990.00 cop[tikv] not(isnull(test.t1.c1)) - └─TableFullScan_42 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo + │ └─IndexMerge_45 2248.30 root type: union + │ ├─IndexRangeScan_41(Build) 3323.33 cop[tikv] table:t1, index:c1(c1) range:[-inf,10), keep order:false, stats:pseudo + │ ├─IndexRangeScan_42(Build) 3323.33 cop[tikv] table:t1, index:c2(c2) range:[-inf,10), keep order:false, stats:pseudo + │ └─Selection_44(Probe) 2248.30 cop[tikv] not(isnull(test.t1.c1)), or(lt(test.t1.c1, 10), and(lt(test.t1.c2, 10), lt(test.t1.c3, 10))) + │ └─TableRowIDScan_43 5542.21 cop[tikv] table:t1 keep order:false, stats:pseudo + └─TableReader_48(Probe) 9990.00 root data:Selection_47 + └─Selection_47 9990.00 cop[tikv] not(isnull(test.t1.c1)) + └─TableFullScan_46 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo delete from t1 where c1 in (select /*+ use_index_merge(t1) */ c1 from t1 where c1 < 10 or c2 < 10 and c3 < 10) order by 1; select * from t1; c1 c2 c3 @@ -409,14 +409,14 @@ Update_10 N/A root N/A └─SelectLock_14 4056.68 root for update 0 └─HashJoin_30 4056.68 root inner join, equal:[eq(test.t1.c1, test.t1.c1)] ├─HashAgg_33(Build) 3245.34 root group by:test.t1.c1, funcs:firstrow(test.t1.c1)->test.t1.c1 - │ └─IndexMerge_38 2248.30 root type: union - │ ├─IndexRangeScan_34(Build) 3323.33 cop[tikv] 
table:t1, index:c1(c1) range:[-inf,10), keep order:false, stats:pseudo - │ ├─IndexRangeScan_35(Build) 3323.33 cop[tikv] table:t1, index:c2(c2) range:[-inf,10), keep order:false, stats:pseudo - │ └─Selection_37(Probe) 2248.30 cop[tikv] not(isnull(test.t1.c1)), or(lt(test.t1.c1, 10), and(lt(test.t1.c2, 10), lt(test.t1.c3, 10))) - │ └─TableRowIDScan_36 5542.21 cop[tikv] table:t1 keep order:false, stats:pseudo - └─TableReader_41(Probe) 9990.00 root data:Selection_40 - └─Selection_40 9990.00 cop[tikv] not(isnull(test.t1.c1)) - └─TableFullScan_39 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo + │ └─IndexMerge_42 2248.30 root type: union + │ ├─IndexRangeScan_38(Build) 3323.33 cop[tikv] table:t1, index:c1(c1) range:[-inf,10), keep order:false, stats:pseudo + │ ├─IndexRangeScan_39(Build) 3323.33 cop[tikv] table:t1, index:c2(c2) range:[-inf,10), keep order:false, stats:pseudo + │ └─Selection_41(Probe) 2248.30 cop[tikv] not(isnull(test.t1.c1)), or(lt(test.t1.c1, 10), and(lt(test.t1.c2, 10), lt(test.t1.c3, 10))) + │ └─TableRowIDScan_40 5542.21 cop[tikv] table:t1 keep order:false, stats:pseudo + └─TableReader_45(Probe) 9990.00 root data:Selection_44 + └─Selection_44 9990.00 cop[tikv] not(isnull(test.t1.c1)) + └─TableFullScan_43 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo update t1 set c1 = 100, c2 = 100, c3 = 100 where c1 in (select /*+ use_index_merge(t1) */ c1 from t1 where c1 < 10 or c2 < 10 and c3 < 10); select * from t1; c1 c2 c3 @@ -471,11 +471,11 @@ insert into t1 values(1, 1, 1), (2, 2, 2), (3, 3, 3), (4, 4, 4), (5, 5, 5); explain select /*+ use_index_merge(t1) */ * from t1 where (c1 < 10 or c2 < 10) and c3 < 10 order by 1 limit 1 offset 2; id estRows task access object operator info TopN_10 1.00 root test.t1.c1, offset:2, count:1 -└─IndexMerge_19 1841.86 root type: union - ├─IndexRangeScan_15(Build) 3323.33 cop[tikv] table:t1, index:c1(c1) range:[-inf,10), keep order:false, stats:pseudo - ├─IndexRangeScan_16(Build) 3323.33 cop[tikv] table:t1, index:c2(c2) range:[-inf,10), keep order:false, stats:pseudo - └─Selection_18(Probe) 1841.86 cop[tikv] lt(test.t1.c3, 10) - └─TableRowIDScan_17 5542.21 cop[tikv] table:t1 keep order:false, stats:pseudo +└─IndexMerge_23 1841.86 root type: union + ├─IndexRangeScan_19(Build) 3323.33 cop[tikv] table:t1, index:c1(c1) range:[-inf,10), keep order:false, stats:pseudo + ├─IndexRangeScan_20(Build) 3323.33 cop[tikv] table:t1, index:c2(c2) range:[-inf,10), keep order:false, stats:pseudo + └─Selection_22(Probe) 1841.86 cop[tikv] lt(test.t1.c3, 10) + └─TableRowIDScan_21 5542.21 cop[tikv] table:t1 keep order:false, stats:pseudo select /*+ use_index_merge(t1) */ * from t1 where (c1 < 10 or c2 < 10) and c3 < 10 order by 1 limit 1 offset 2; c1 c2 c3 3 3 3 @@ -483,13 +483,13 @@ c1 c2 c3 explain select /*+ use_index_merge(t1) */ sum(c1) from t1 where (c1 < 10 or c2 < 10) and c3 < 10 group by c1 order by 1; id estRows task access object operator info Sort_6 1473.49 root Column#5 -└─HashAgg_11 1473.49 root group by:Column#10, funcs:sum(Column#9)->Column#5 - └─Projection_18 1841.86 root cast(test.t1.c1, decimal(10,0) BINARY)->Column#9, test.t1.c1 - └─IndexMerge_16 1841.86 root type: union - ├─IndexRangeScan_12(Build) 3323.33 cop[tikv] table:t1, index:c1(c1) range:[-inf,10), keep order:false, stats:pseudo - ├─IndexRangeScan_13(Build) 3323.33 cop[tikv] table:t1, index:c2(c2) range:[-inf,10), keep order:false, stats:pseudo - └─Selection_15(Probe) 1841.86 cop[tikv] lt(test.t1.c3, 10) - └─TableRowIDScan_14 5542.21 cop[tikv] table:t1 keep order:false, 
stats:pseudo +└─HashAgg_11 1473.49 root group by:Column#13, funcs:sum(Column#12)->Column#5 + └─Projection_22 1841.86 root cast(test.t1.c1, decimal(10,0) BINARY)->Column#12, test.t1.c1 + └─IndexMerge_20 1841.86 root type: union + ├─IndexRangeScan_16(Build) 3323.33 cop[tikv] table:t1, index:c1(c1) range:[-inf,10), keep order:false, stats:pseudo + ├─IndexRangeScan_17(Build) 3323.33 cop[tikv] table:t1, index:c2(c2) range:[-inf,10), keep order:false, stats:pseudo + └─Selection_19(Probe) 1841.86 cop[tikv] lt(test.t1.c3, 10) + └─TableRowIDScan_18 5542.21 cop[tikv] table:t1 keep order:false, stats:pseudo select /*+ use_index_merge(t1) */ sum(c1) from t1 where (c1 < 10 or c2 < 10) and c3 < 10 group by c1 order by 1; sum(c1) 1 @@ -536,8 +536,8 @@ Sort_16 1841.86 root test.t1.c1 │ └─Selection_25(Probe) 1841.86 cop[tikv] lt(test.t1.c3, 10) │ └─TableRowIDScan_24 5542.21 cop[tikv] table:t1 keep order:false, stats:pseudo └─TopN_29(Probe) 1841.86 root test.t2.c1, offset:2, count:1 - └─HashAgg_36 4900166.23 root group by:Column#21, funcs:avg(Column#19)->Column#9, funcs:firstrow(Column#20)->test.t2.c1 - └─Projection_48 6125207.79 root cast(test.t2.c1, decimal(10,0) BINARY)->Column#19, test.t2.c1, test.t2.c1 + └─HashAgg_35 4900166.23 root group by:Column#24, funcs:avg(Column#22)->Column#9, funcs:firstrow(Column#23)->test.t2.c1 + └─Projection_53 6125207.79 root cast(test.t2.c1, decimal(10,0) BINARY)->Column#22, test.t2.c1, test.t2.c1 └─IndexMerge_41 6125207.79 root type: union ├─Selection_38(Build) 6121.12 cop[tikv] eq(test.t1.c1, test.t2.c1) │ └─IndexRangeScan_37 6121120.92 cop[tikv] table:t2, index:c1(c1) range:[-inf,10), keep order:false, stats:pseudo diff --git a/config/config.go b/config/config.go index 93960bb46d199..f4b353156698b 100644 --- a/config/config.go +++ b/config/config.go @@ -974,7 +974,7 @@ var defaultConf = Config{ }, Experimental: Experimental{}, EnableCollectExecutionInfo: true, - EnableTelemetry: true, + EnableTelemetry: false, Labels: make(map[string]string), EnableGlobalIndex: false, Security: Security{ diff --git a/config/config.toml.example b/config/config.toml.example index 588379f204602..2d3b1acd0b29c 100644 --- a/config/config.toml.example +++ b/config/config.toml.example @@ -97,7 +97,7 @@ skip-register-to-dashboard = false # When enabled, usage data (for example, instance versions) will be reported to PingCAP periodically for user experience analytics. # If this config is set to `false` on all TiDB servers, telemetry will be always disabled regardless of the value of the global variable `tidb_enable_telemetry`. # See PingCAP privacy policy for details: https://pingcap.com/en/privacy-policy/ -enable-telemetry = true +enable-telemetry = false # deprecate-integer-display-length is used to be compatible with MySQL 8.0 in which the integer declared with display length will be returned with # a warning like `Integer display width is deprecated and will be removed in a future release`. 
diff --git a/config/config_test.go b/config/config_test.go index 9a6d12a284817..f4909ee15c23f 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -858,7 +858,7 @@ history-size=100`) require.NoError(t, err) require.NoError(t, f.Sync()) require.NoError(t, conf.Load(configFile)) - require.True(t, conf.EnableTelemetry) + require.False(t, conf.EnableTelemetry) _, err = f.WriteString(` enable-table-lock = true @@ -866,15 +866,15 @@ enable-table-lock = true require.NoError(t, err) require.NoError(t, f.Sync()) require.NoError(t, conf.Load(configFile)) - require.True(t, conf.EnableTelemetry) + require.False(t, conf.EnableTelemetry) _, err = f.WriteString(` -enable-telemetry = false +enable-telemetry = true `) require.NoError(t, err) require.NoError(t, f.Sync()) require.NoError(t, conf.Load(configFile)) - require.False(t, conf.EnableTelemetry) + require.True(t, conf.EnableTelemetry) _, err = f.WriteString(` [security] diff --git a/ddl/backfilling.go b/ddl/backfilling.go index 9873289815470..192840f35f5a5 100644 --- a/ddl/backfilling.go +++ b/ddl/backfilling.go @@ -867,7 +867,8 @@ func (dc *ddlCtx) writePhysicalTableRecord(sessPool *sessionPool, t table.Physic if len(remains) > 0 { startKey = remains[0].StartKey } else { - startKey = kvRanges[len(kvRanges)-1].EndKey + rangeEndKey := kvRanges[len(kvRanges)-1].EndKey + startKey = rangeEndKey.Next() } if startKey.Cmp(endKey) >= 0 { break diff --git a/ddl/column.go b/ddl/column.go index ec5aeb90157be..dc9bfbcafaa5a 100644 --- a/ddl/column.go +++ b/ddl/column.go @@ -364,7 +364,8 @@ func needChangeColumnData(oldCol, newCol *model.ColumnInfo) bool { toUnsigned := mysql.HasUnsignedFlag(newCol.GetFlag()) originUnsigned := mysql.HasUnsignedFlag(oldCol.GetFlag()) needTruncationOrToggleSign := func() bool { - return (newCol.GetFlen() > 0 && newCol.GetFlen() < oldCol.GetFlen()) || (toUnsigned != originUnsigned) + return (newCol.GetFlen() > 0 && (newCol.GetFlen() < oldCol.GetFlen() || newCol.GetDecimal() < oldCol.GetDecimal())) || + (toUnsigned != originUnsigned) } // Ignore the potential max display length represented by integer's flen, use default flen instead. 
defaultOldColFlen, _ := mysql.GetDefaultFieldLengthAndDecimal(oldCol.GetType()) diff --git a/ddl/column_type_change_test.go b/ddl/column_type_change_test.go index e93c58225088a..04cd0a40508cf 100644 --- a/ddl/column_type_change_test.go +++ b/ddl/column_type_change_test.go @@ -2321,11 +2321,15 @@ func TestColumnTypeChangeBetweenFloatAndDouble(t *testing.T) { prepare := func(createTableStmt string) { tk.MustExec("drop table if exists t;") tk.MustExec(createTableStmt) - tk.MustExec("insert into t values (36.4), (24.1);") + tk.MustExec("insert into t values (36.43), (24.1);") } prepare("create table t (a float(6,2));") tk.MustExec("alter table t modify a double(6,2)") + tk.MustQuery("select a from t;").Check(testkit.Rows("36.43", "24.1")) + + prepare("create table t (a float(6,2));") + tk.MustExec("alter table t modify a float(6,1)") tk.MustQuery("select a from t;").Check(testkit.Rows("36.4", "24.1")) prepare("create table t (a double(6,2));") diff --git a/ddl/db_test.go b/ddl/db_test.go index edc891ad16ccf..d36de3a425025 100644 --- a/ddl/db_test.go +++ b/ddl/db_test.go @@ -1585,11 +1585,13 @@ func TestLogAndShowSlowLog(t *testing.T) { } func TestReportingMinStartTimestamp(t *testing.T) { - store, dom := testkit.CreateMockStoreAndDomainWithSchemaLease(t, dbTestLease) - tk := testkit.NewTestKit(t, store) - se := tk.Session() + _, dom := testkit.CreateMockStoreAndDomainWithSchemaLease(t, dbTestLease) infoSyncer := dom.InfoSyncer() + sm := &testkit.MockSessionManager{ + PS: make([]*util.ProcessInfo, 0), + } + infoSyncer.SetSessionManager(sm) beforeTS := oracle.GoTimeToTS(time.Now()) infoSyncer.ReportMinStartTS(dom.Store()) afterTS := oracle.GoTimeToTS(time.Now()) @@ -1598,21 +1600,13 @@ func TestReportingMinStartTimestamp(t *testing.T) { now := time.Now() validTS := oracle.GoTimeToLowerLimitStartTS(now.Add(time.Minute), tikv.MaxTxnTimeUse) lowerLimit := oracle.GoTimeToLowerLimitStartTS(now, tikv.MaxTxnTimeUse) - sm := se.GetSessionManager().(*testkit.MockSessionManager) sm.PS = []*util.ProcessInfo{ - {CurTxnStartTS: 0, ProtectedTSList: &se.GetSessionVars().ProtectedTSList}, - {CurTxnStartTS: math.MaxUint64, ProtectedTSList: &se.GetSessionVars().ProtectedTSList}, - {CurTxnStartTS: lowerLimit, ProtectedTSList: &se.GetSessionVars().ProtectedTSList}, - {CurTxnStartTS: validTS, ProtectedTSList: &se.GetSessionVars().ProtectedTSList}, + {CurTxnStartTS: 0}, + {CurTxnStartTS: math.MaxUint64}, + {CurTxnStartTS: lowerLimit}, + {CurTxnStartTS: validTS}, } - infoSyncer.ReportMinStartTS(dom.Store()) - require.Equal(t, validTS, infoSyncer.GetMinStartTS()) - - unhold := se.GetSessionVars().ProtectedTSList.HoldTS(validTS - 1) - infoSyncer.ReportMinStartTS(dom.Store()) - require.Equal(t, validTS-1, infoSyncer.GetMinStartTS()) - - unhold() + infoSyncer.SetSessionManager(sm) infoSyncer.ReportMinStartTS(dom.Store()) require.Equal(t, validTS, infoSyncer.GetMinStartTS()) } diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 19f1ce98dde73..37af43658d454 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -2548,6 +2548,12 @@ func (d *ddl) BatchCreateTableWithInfo(ctx sessionctx.Context, infos []*model.TableInfo, cs ...CreateTableWithInfoConfigurier, ) error { + failpoint.Inject("RestoreBatchCreateTableEntryTooLarge", func(val failpoint.Value) { + injectBatchSize := val.(int) + if len(infos) > injectBatchSize { + failpoint.Return(kv.ErrEntryTooLarge) + } + }) c := GetCreateTableWithInfoConfig(cs) jobs := &model.Job{ diff --git a/ddl/index.go b/ddl/index.go index 7ddd2d3e2cc2f..bb83c5c5cacc2 100644 --- a/ddl/index.go +++ 
b/ddl/index.go @@ -817,8 +817,14 @@ func doReorgWorkForCreateIndex(w *worker, d *ddlCtx, t *meta.Meta, job *model.Jo } done, ver, err = runReorgJobAndHandleErr(w, d, t, job, tbl, indexInfo, false) if err != nil { + if common.ErrFoundDuplicateKeys.Equal(err) { + err = convertToKeyExistsErr(err, indexInfo, tbl.Meta()) + logutil.BgLogger().Warn("[ddl] found duplicate key, convert job to rollback", zap.String("job", job.String()), zap.Error(err)) + ver, err = convertAddIdxJob2RollbackJob(d, t, job, tbl.Meta(), indexInfo, err) + } else { + err = tryFallbackToTxnMerge(job, err) + } ingest.LitBackCtxMgr.Unregister(job.ID) - err = tryFallbackToTxnMerge(job, err) return false, ver, errors.Trace(err) } if !done { @@ -826,11 +832,11 @@ func doReorgWorkForCreateIndex(w *worker, d *ddlCtx, t *meta.Meta, job *model.Jo } err = bc.FinishImport(indexInfo.ID, indexInfo.Unique, tbl) if err != nil { - if kv.ErrKeyExists.Equal(err) || common.ErrFoundDuplicateKeys.Equal(err) { + if common.ErrFoundDuplicateKeys.Equal(err) { + err = convertToKeyExistsErr(err, indexInfo, tbl.Meta()) + } + if kv.ErrKeyExists.Equal(err) { logutil.BgLogger().Warn("[ddl] import index duplicate key, convert job to rollback", zap.String("job", job.String()), zap.Error(err)) - if common.ErrFoundDuplicateKeys.Equal(err) { - err = convertToKeyExistsErr(err, indexInfo, tbl.Meta()) - } ver, err = convertAddIdxJob2RollbackJob(d, t, job, tbl.Meta(), indexInfo, err) } else { logutil.BgLogger().Warn("[ddl] lightning import error", zap.Error(err)) diff --git a/ddl/index_merge_tmp.go b/ddl/index_merge_tmp.go index b7fca70b8ed34..0150f0fb42b4c 100644 --- a/ddl/index_merge_tmp.go +++ b/ddl/index_merge_tmp.go @@ -223,7 +223,6 @@ func (w *mergeIndexWorker) fetchTempIndexVals(txn kv.Transaction, taskRange reor oprStartTime := startTime idxPrefix := w.table.IndexPrefix() var lastKey kv.Key - isCommonHandle := w.table.Meta().IsCommonHandle err := iterateSnapshotKeys(w.reorgInfo.d.jobContext(w.reorgInfo.Job), w.sessCtx.GetStore(), w.priority, idxPrefix, txn.StartTS(), taskRange.startKey, taskRange.endKey, func(_ kv.Handle, indexKey kv.Key, rawValue []byte) (more bool, err error) { oprEndTime := time.Now() @@ -240,10 +239,15 @@ func (w *mergeIndexWorker) fetchTempIndexVals(txn kv.Transaction, taskRange reor return false, nil } - tempIdxVal, err := tablecodec.DecodeTempIndexValue(rawValue, isCommonHandle) + tempIdxVal, err := tablecodec.DecodeTempIndexValue(rawValue) if err != nil { return false, err } + tempIdxVal, err = decodeTempIndexHandleFromIndexKV(indexKey, tempIdxVal, len(w.index.Meta().Columns)) + if err != nil { + return false, err + } + tempIdxVal = tempIdxVal.FilterOverwritten() // Extract the operations on the original index and replay them later. @@ -254,19 +258,9 @@ func (w *mergeIndexWorker) fetchTempIndexVals(txn kv.Transaction, taskRange reor continue } - if elem.Handle == nil { - // If the handle is not found in the value of the temp index, it means - // 1) This is not a deletion marker, the handle is in the key or the origin value. - // 2) This is a deletion marker, but the handle is in the key of temp index. 
- elem.Handle, err = tablecodec.DecodeIndexHandle(indexKey, elem.Value, len(w.index.Meta().Columns)) - if err != nil { - return false, err - } - } - originIdxKey := make([]byte, len(indexKey)) copy(originIdxKey, indexKey) - tablecodec.TempIndexKey2IndexKey(w.index.Meta().ID, originIdxKey) + tablecodec.TempIndexKey2IndexKey(originIdxKey) idxRecord := &temporaryIndexRecord{ handle: elem.Handle, @@ -301,3 +295,18 @@ func (w *mergeIndexWorker) fetchTempIndexVals(txn kv.Transaction, taskRange reor zap.String("taskRange", taskRange.String()), zap.Duration("takeTime", time.Since(startTime))) return w.tmpIdxRecords, nextKey.Next(), taskDone, errors.Trace(err) } + +func decodeTempIndexHandleFromIndexKV(indexKey kv.Key, tmpVal tablecodec.TempIndexValue, idxColLen int) (ret tablecodec.TempIndexValue, err error) { + for _, elem := range tmpVal { + if elem.Handle == nil { + // If the handle is not found in the value of the temp index, it means + // 1) This is not a deletion marker, the handle is in the key or the origin value. + // 2) This is a deletion marker, but the handle is in the key of temp index. + elem.Handle, err = tablecodec.DecodeIndexHandle(indexKey, elem.Value, idxColLen) + if err != nil { + return nil, err + } + } + } + return tmpVal, nil +} diff --git a/ddl/index_merge_tmp_test.go b/ddl/index_merge_tmp_test.go index 42cf75bb4a2cd..8dd5283d6c3de 100644 --- a/ddl/index_merge_tmp_test.go +++ b/ddl/index_merge_tmp_test.go @@ -870,3 +870,44 @@ func TestAddIndexMultipleDelete(t *testing.T) { tk.MustQuery("select * from t;").Check(testkit.Rows()) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/ddl/mockDMLExecution")) } + +func TestAddIndexUpdateUntouchedValues(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t(id int primary key, b int, k int);") + tk.MustExec("insert into t values (1, 1, 1);") + + tk1 := testkit.NewTestKit(t, store) + tk1.MustExec("use test") + + d := dom.DDL() + originalCallback := d.GetHook() + defer d.SetHook(originalCallback) + callback := &ddl.TestDDLCallback{} + var runDML bool + callback.OnJobRunBeforeExported = func(job *model.Job) { + if t.Failed() || runDML { + return + } + switch job.SchemaState { + case model.StateWriteReorganization: + _, err := tk1.Exec("begin;") + assert.NoError(t, err) + _, err = tk1.Exec("update t set k=k+1 where id = 1;") + assert.NoError(t, err) + _, err = tk1.Exec("insert into t values (2, 1, 2);") + // Should not report "invalid temp index value". 
+ assert.NoError(t, err) + _, err = tk1.Exec("commit;") + assert.NoError(t, err) + runDML = true + } + } + d.SetHook(callback) + + tk.MustGetErrCode("alter table t add unique index idx(b);", errno.ErrDupEntry) + tk.MustExec("admin check table t;") + tk.MustQuery("select * from t;").Check(testkit.Rows("1 1 2", "2 1 2")) +} diff --git a/ddl/ingest/config.go b/ddl/ingest/config.go index 7fd251a361939..f611369decae0 100644 --- a/ddl/ingest/config.go +++ b/ddl/ingest/config.go @@ -58,6 +58,8 @@ func generateLightningConfig(memRoot MemRoot, jobID int64, unique bool) (*config cfg.Security.CAPath = tidbCfg.Security.ClusterSSLCA cfg.Security.CertPath = tidbCfg.Security.ClusterSSLCert cfg.Security.KeyPath = tidbCfg.Security.ClusterSSLKey + // in DDL scenario, we don't switch import mode + cfg.Cron.SwitchMode = config.Duration{Duration: 0} return cfg, err } diff --git a/ddl/ingest/engine.go b/ddl/ingest/engine.go index ac0287b26637a..fce350df09e40 100644 --- a/ddl/ingest/engine.go +++ b/ddl/ingest/engine.go @@ -204,7 +204,7 @@ func (ei *engineInfo) newWriterContext(workerID int, unique bool) (*WriterContex func (ei *engineInfo) closeWriters() error { var firstErr error - for wid := range ei.writerCache.Keys() { + for _, wid := range ei.writerCache.Keys() { if w, ok := ei.writerCache.Load(wid); ok { _, err := w.Close(ei.ctx) if err != nil { diff --git a/domain/infosync/info.go b/domain/infosync/info.go index 3d45ce691e252..c501d7f16d695 100644 --- a/domain/infosync/info.go +++ b/domain/infosync/info.go @@ -689,6 +689,8 @@ func (is *InfoSyncer) ReportMinStartTS(store kv.Storage) { if sm == nil { return } + pl := sm.ShowProcessList() + innerSessionStartTSList := sm.GetInternalSessionStartTSList() // Calculate the lower limit of the start timestamp to avoid extremely old transaction delaying GC. 
currentVer, err := store.CurrentVersion(kv.GlobalTxnScope) @@ -702,8 +704,18 @@ func (is *InfoSyncer) ReportMinStartTS(store kv.Storage) { minStartTS := oracle.GoTimeToTS(now) logutil.BgLogger().Debug("ReportMinStartTS", zap.Uint64("initial minStartTS", minStartTS), zap.Uint64("StartTSLowerLimit", startTSLowerLimit)) - if ts := sm.GetMinStartTS(startTSLowerLimit); ts > startTSLowerLimit && ts < minStartTS { - minStartTS = ts + for _, info := range pl { + if info.CurTxnStartTS > startTSLowerLimit && info.CurTxnStartTS < minStartTS { + minStartTS = info.CurTxnStartTS + } + } + + for _, innerTS := range innerSessionStartTSList { + logutil.BgLogger().Debug("ReportMinStartTS", zap.Uint64("Internal Session Transaction StartTS", innerTS)) + kv.PrintLongTimeInternalTxn(now, innerTS, false) + if innerTS > startTSLowerLimit && innerTS < minStartTS { + minStartTS = innerTS + } } is.minStartTS = kv.GetMinInnerTxnStartTS(now, startTSLowerLimit, minStartTS) diff --git a/executor/analyzetest/analyze_test.go b/executor/analyzetest/analyze_test.go index d3aa31494bcfe..a9f8c6f12f915 100644 --- a/executor/analyzetest/analyze_test.go +++ b/executor/analyzetest/analyze_test.go @@ -3188,13 +3188,12 @@ func TestGlobalMemoryControlForAnalyze(t *testing.T) { sql := "analyze table t with 1.0 samplerate;" // Need about 100MB require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/util/memory/ReadMemStats", `return(536870912)`)) require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/mockAnalyzeMergeWorkerSlowConsume", `return(100)`)) - defer func() { - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/util/memory/ReadMemStats")) - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/mockAnalyzeMergeWorkerSlowConsume")) - }() _, err := tk0.Exec(sql) require.True(t, strings.Contains(err.Error(), "Out Of Memory Quota!")) runtime.GC() + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/util/memory/ReadMemStats")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/mockAnalyzeMergeWorkerSlowConsume")) + tk0.MustExec(sql) } func TestGlobalMemoryControlForAutoAnalyze(t *testing.T) { diff --git a/executor/batch_point_get.go b/executor/batch_point_get.go index ee9808700aaec..965708073a097 100644 --- a/executor/batch_point_get.go +++ b/executor/batch_point_get.go @@ -392,23 +392,6 @@ func (e *BatchPointGetExec) initialize(ctx context.Context) error { if err != nil { return err } - // Change the unique index LOCK into PUT record. - if len(indexKeys) > 0 { - if !e.txn.Valid() { - return kv.ErrInvalidTxn - } - membuf := e.txn.GetMemBuffer() - for _, idxKey := range indexKeys { - handleVal := handleVals[string(idxKey)] - if len(handleVal) == 0 { - continue - } - err = membuf.Set(idxKey, handleVal) - if err != nil { - return err - } - } - } } // Fetch all values. 
values, err = batchGetter.BatchGet(ctx, keys) @@ -420,7 +403,6 @@ func (e *BatchPointGetExec) initialize(ctx context.Context) error { if e.lock && rc { existKeys = make([]kv.Key, 0, 2*len(values)) } - changeLockToPutIdxKeys := make([]kv.Key, 0, len(indexKeys)) e.values = make([][]byte, 0, len(values)) for i, key := range keys { val := values[string(key)] @@ -455,7 +437,6 @@ func (e *BatchPointGetExec) initialize(ctx context.Context) error { // lock primary key for clustered index table is redundant if len(indexKeys) != 0 { existKeys = append(existKeys, indexKeys[i]) - changeLockToPutIdxKeys = append(changeLockToPutIdxKeys, indexKeys[i]) } } } @@ -465,22 +446,6 @@ func (e *BatchPointGetExec) initialize(ctx context.Context) error { if err != nil { return err } - if len(changeLockToPutIdxKeys) > 0 { - if !e.txn.Valid() { - return kv.ErrInvalidTxn - } - for _, idxKey := range changeLockToPutIdxKeys { - membuf := e.txn.GetMemBuffer() - handleVal := handleVals[string(idxKey)] - if len(handleVal) == 0 { - return kv.ErrNotExist - } - err = membuf.Set(idxKey, handleVal) - if err != nil { - return err - } - } - } } e.handles = handles return nil diff --git a/executor/distsql.go b/executor/distsql.go index 4d760c8592f5d..3d7d54bf6824c 100644 --- a/executor/distsql.go +++ b/executor/distsql.go @@ -75,7 +75,9 @@ type lookupTableTask struct { idxRows *chunk.Chunk cursor int - doneCh chan error + // after the cop task is built, buildDone will be set to the current instant, for Next wait duration statistic. + buildDoneTime time.Time + doneCh chan error // indexOrder map is used to save the original index order for the handles. // Without this map, the original index order might be lost. @@ -793,13 +795,32 @@ func (e *IndexLookUpExecutor) getResultTask() (*lookupTableTask, error) { if e.resultCurr != nil && e.resultCurr.cursor < len(e.resultCurr.rows) { return e.resultCurr, nil } + var ( + enableStats = e.stats != nil + start time.Time + indexFetchedInstant time.Time + ) + if enableStats { + start = time.Now() + } task, ok := <-e.resultCh if !ok { return nil, nil } + if enableStats { + indexFetchedInstant = time.Now() + } if err := <-task.doneCh; err != nil { return nil, err } + if enableStats { + e.stats.NextWaitIndexScan += indexFetchedInstant.Sub(start) + if task.buildDoneTime.After(indexFetchedInstant) { + e.stats.NextWaitTableLookUpBuild += task.buildDoneTime.Sub(indexFetchedInstant) + indexFetchedInstant = task.buildDoneTime + } + e.stats.NextWaitTableLookUpResp += time.Since(indexFetchedInstant) + } // Release the memory usage of last task before we handle a new task. if e.resultCurr != nil { @@ -1122,6 +1143,11 @@ type IndexLookUpRunTimeStats struct { TableRowScan int64 TableTaskNum int64 Concurrency int + + // Record the `Next` call affected wait duration details. 
+ NextWaitIndexScan time.Duration + NextWaitTableLookUpBuild time.Duration + NextWaitTableLookUpResp time.Duration } func (e *IndexLookUpRunTimeStats) String() string { @@ -1145,6 +1171,16 @@ func (e *IndexLookUpRunTimeStats) String() string { } buf.WriteString(fmt.Sprintf(" table_task: {total_time: %v, num: %d, concurrency: %d}", execdetails.FormatDuration(time.Duration(tableScan)), tableTaskNum, concurrency)) } + + if e.NextWaitIndexScan > 0 || e.NextWaitTableLookUpBuild > 0 || e.NextWaitTableLookUpResp > 0 { + if buf.Len() > 0 { + buf.WriteByte(',') + fmt.Fprintf(&buf, " next: {wait_index: %s, wait_table_lookup_build: %s, wait_table_lookup_resp: %s}", + execdetails.FormatDuration(e.NextWaitIndexScan), + execdetails.FormatDuration(e.NextWaitTableLookUpBuild), + execdetails.FormatDuration(e.NextWaitTableLookUpResp)) + } + } return buf.String() } @@ -1165,6 +1201,9 @@ func (e *IndexLookUpRunTimeStats) Merge(other execdetails.RuntimeStats) { e.TaskWait += tmp.TaskWait e.TableRowScan += tmp.TableRowScan e.TableTaskNum += tmp.TableTaskNum + e.NextWaitIndexScan += tmp.NextWaitIndexScan + e.NextWaitTableLookUpBuild += tmp.NextWaitTableLookUpBuild + e.NextWaitTableLookUpResp += tmp.NextWaitTableLookUpResp } // Tp implements the RuntimeStats interface. @@ -1312,6 +1351,7 @@ func getDatumRow(r *chunk.Row, fields []*types.FieldType) []types.Datum { // Then we hold the returning rows and finish this task. func (w *tableWorker) executeTask(ctx context.Context, task *lookupTableTask) error { tableReader, err := w.idxLookup.buildTableReader(ctx, task) + task.buildDoneTime = time.Now() if err != nil { logutil.Logger(ctx).Error("build table reader failed", zap.Error(err)) return err diff --git a/executor/distsql_test.go b/executor/distsql_test.go index 4420a714e96cf..fbc1014dd6a27 100644 --- a/executor/distsql_test.go +++ b/executor/distsql_test.go @@ -358,17 +358,24 @@ func TestPartitionTableRandomlyIndexLookUpReader(t *testing.T) { func TestIndexLookUpStats(t *testing.T) { stats := &executor.IndexLookUpRunTimeStats{ - FetchHandleTotal: int64(5 * time.Second), - FetchHandle: int64(2 * time.Second), - TaskWait: int64(2 * time.Second), - TableRowScan: int64(2 * time.Second), - TableTaskNum: 2, - Concurrency: 1, + FetchHandleTotal: int64(5 * time.Second), + FetchHandle: int64(2 * time.Second), + TaskWait: int64(2 * time.Second), + TableRowScan: int64(2 * time.Second), + TableTaskNum: 2, + Concurrency: 1, + NextWaitIndexScan: time.Second, + NextWaitTableLookUpBuild: 2 * time.Second, + NextWaitTableLookUpResp: 3 * time.Second, } - require.Equal(t, "index_task: {total_time: 5s, fetch_handle: 2s, build: 1s, wait: 2s}, table_task: {total_time: 2s, num: 2, concurrency: 1}", stats.String()) + require.Equal(t, "index_task: {total_time: 5s, fetch_handle: 2s, build: 1s, wait: 2s}"+ + ", table_task: {total_time: 2s, num: 2, concurrency: 1}"+ + ", next: {wait_index: 1s, wait_table_lookup_build: 2s, wait_table_lookup_resp: 3s}", stats.String()) require.Equal(t, stats.Clone().String(), stats.String()) stats.Merge(stats.Clone()) - require.Equal(t, "index_task: {total_time: 10s, fetch_handle: 4s, build: 2s, wait: 4s}, table_task: {total_time: 4s, num: 4, concurrency: 1}", stats.String()) + require.Equal(t, "index_task: {total_time: 10s, fetch_handle: 4s, build: 2s, wait: 4s}"+ + ", table_task: {total_time: 4s, num: 4, concurrency: 1}"+ + ", next: {wait_index: 2s, wait_table_lookup_build: 4s, wait_table_lookup_resp: 6s}", stats.String()) } func TestIndexLookUpGetResultChunk(t *testing.T) { diff --git a/executor/executor.go 
b/executor/executor.go index 08d137e252b59..98585b905d026 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1937,7 +1937,7 @@ func (e *UnionExec) Close() error { func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { vars := ctx.GetSessionVars() var sc *stmtctx.StatementContext - if vars.TxnCtx.CouldRetry { + if vars.TxnCtx.CouldRetry || mysql.HasCursorExistsFlag(vars.Status) { // Must construct new statement context object, the retry history need context for every statement. // TODO: Maybe one day we can get rid of transaction retry, then this logic can be deleted. sc = &stmtctx.StatementContext{} diff --git a/executor/explainfor_test.go b/executor/explainfor_test.go index ddb0578338c6f..d8aaf71ac8900 100644 --- a/executor/explainfor_test.go +++ b/executor/explainfor_test.go @@ -780,7 +780,7 @@ func TestIndexMerge4PlanCache(t *testing.T) { ps := []*util.ProcessInfo{tkProcess} tk.Session().SetSessionManager(&testkit.MockSessionManager{PS: ps}) res := tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) - require.Len(t, res.Rows(), 7) + require.Len(t, res.Rows(), 8) require.Regexp(t, ".*Selection.*", res.Rows()[1][0]) require.Regexp(t, ".*IndexMerge.*", res.Rows()[2][0]) require.Regexp(t, ".*IndexRangeScan.*", res.Rows()[4][0]) diff --git a/executor/grant.go b/executor/grant.go index 3bae8e4956075..2933126cb855a 100644 --- a/executor/grant.go +++ b/executor/grant.go @@ -453,6 +453,9 @@ func (e *GrantExec) grantLevelPriv(priv *ast.PrivElem, user *ast.UserSpec, inter if priv.Priv == mysql.ExtendedPriv { return e.grantDynamicPriv(priv.Name, user, internalSession) } + if priv.Priv == mysql.UsagePriv { + return nil + } switch e.Level.Level { case ast.GrantLevelGlobal: return e.grantGlobalLevel(priv, user, internalSession) @@ -491,10 +494,6 @@ func (e *GrantExec) grantDynamicPriv(privName string, user *ast.UserSpec, intern // grantGlobalLevel manipulates mysql.user table. func (e *GrantExec) grantGlobalLevel(priv *ast.PrivElem, user *ast.UserSpec, internalSession sessionctx.Context) error { - if priv.Priv == 0 || priv.Priv == mysql.UsagePriv { - return nil - } - sql := new(strings.Builder) sqlexec.MustFormatSQL(sql, `UPDATE %n.%n SET `, mysql.SystemDB, mysql.UserTable) err := composeGlobalPrivUpdate(sql, priv.Priv, "Y") @@ -510,9 +509,6 @@ func (e *GrantExec) grantGlobalLevel(priv *ast.PrivElem, user *ast.UserSpec, int // grantDBLevel manipulates mysql.db table. func (e *GrantExec) grantDBLevel(priv *ast.PrivElem, user *ast.UserSpec, internalSession sessionctx.Context) error { - if priv.Priv == mysql.UsagePriv { - return nil - } for _, v := range mysql.StaticGlobalOnlyPrivs { if v == priv.Priv { return ErrWrongUsage.GenWithStackByArgs("DB GRANT", "GLOBAL PRIVILEGES") @@ -539,9 +535,6 @@ func (e *GrantExec) grantDBLevel(priv *ast.PrivElem, user *ast.UserSpec, interna // grantTableLevel manipulates mysql.tables_priv table. 
func (e *GrantExec) grantTableLevel(priv *ast.PrivElem, user *ast.UserSpec, internalSession sessionctx.Context) error { - if priv.Priv == mysql.UsagePriv { - return nil - } dbName := e.Level.DBName if len(dbName) == 0 { dbName = e.ctx.GetSessionVars().CurrentDB diff --git a/executor/index_merge_reader.go b/executor/index_merge_reader.go index fe62f20fff816..5887ec1c06844 100644 --- a/executor/index_merge_reader.go +++ b/executor/index_merge_reader.go @@ -294,7 +294,7 @@ func (e *IndexMergeReaderExecutor) startIndexMergeProcessWorker(ctx context.Cont func (e *IndexMergeReaderExecutor) startPartialIndexWorker(ctx context.Context, exitCh <-chan struct{}, fetchCh chan<- *indexMergeTableTask, workID int) error { failpoint.Inject("testIndexMergeResultChCloseEarly", func(_ failpoint.Value) { // Wait for processWorker to close resultCh. - time.Sleep(2) + time.Sleep(time.Second * 2) // Should use fetchCh instead of resultCh to send error. syncErr(ctx, e.finished, fetchCh, errors.New("testIndexMergeResultChCloseEarly")) }) @@ -365,11 +365,20 @@ func (e *IndexMergeReaderExecutor) startPartialIndexWorker(ctx context.Context, SetFromInfoSchema(e.ctx.GetInfoSchema()). SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.ctx, &builder.Request, e.partialNetDataSizes[workID])) + var notClosedSelectResult distsql.SelectResult + defer func() { + // To make sure SelectResult.Close() is called even got panic in fetchHandles(). + if notClosedSelectResult != nil { + terror.Call(notClosedSelectResult.Close) + } + }() for parTblIdx, keyRange := range keyRanges { // check if this executor is closed select { + case <-ctx.Done(): + return case <-e.finished: - break + return default: } @@ -388,6 +397,8 @@ func (e *IndexMergeReaderExecutor) startPartialIndexWorker(ctx context.Context, syncErr(ctx, e.finished, fetchCh, err) return } + notClosedSelectResult = result + failpoint.Inject("testIndexMergePartialIndexWorkerCoprLeak", nil) worker.batchSize = e.maxChunkSize if worker.batchSize > worker.maxBatchSize { worker.batchSize = worker.maxBatchSize @@ -402,6 +413,7 @@ func (e *IndexMergeReaderExecutor) startPartialIndexWorker(ctx context.Context, if fetchErr != nil { // this error is synced in fetchHandles(), don't sync it again e.feedbacks[workID].Invalidate() } + notClosedSelectResult = nil if err := result.Close(); err != nil { logutil.Logger(ctx).Error("close Select result failed:", zap.Error(err)) } @@ -479,11 +491,20 @@ func (e *IndexMergeReaderExecutor) startPartialTableWorker(ctx context.Context, partialTableReader.dagPB = e.dagPBs[workID] } + var tableReaderClosed bool + defer func() { + // To make sure SelectResult.Close() is called even got panic in fetchHandles(). 
+ if !tableReaderClosed { + terror.Call(worker.tableReader.Close) + } + }() for parTblIdx, tbl := range tbls { // check if this executor is closed select { + case <-ctx.Done(): + return case <-e.finished: - break + return default: } @@ -494,6 +515,8 @@ func (e *IndexMergeReaderExecutor) startPartialTableWorker(ctx context.Context, syncErr(ctx, e.finished, fetchCh, err) break } + failpoint.Inject("testIndexMergePartialTableWorkerCoprLeak", nil) + tableReaderClosed = false worker.batchSize = e.maxChunkSize if worker.batchSize > worker.maxBatchSize { worker.batchSize = worker.maxBatchSize @@ -511,6 +534,7 @@ func (e *IndexMergeReaderExecutor) startPartialTableWorker(ctx context.Context, // release related resources cancel() + tableReaderClosed = true if err = worker.tableReader.Close(); err != nil { logutil.Logger(ctx).Error("close Select result failed:", zap.Error(err)) } @@ -731,6 +755,12 @@ func (e *IndexMergeReaderExecutor) Next(ctx context.Context, req *chunk.Chunk) e } func (e *IndexMergeReaderExecutor) getResultTask() (*indexMergeTableTask, error) { + failpoint.Inject("testIndexMergeMainReturnEarly", func(_ failpoint.Value) { + // To make sure processWorker make resultCh to be full. + // When main goroutine close finished, processWorker may be stuck when writing resultCh. + time.Sleep(time.Second * 20) + failpoint.Return(nil, errors.New("failpoint testIndexMergeMainReturnEarly")) + }) if e.resultCurr != nil && e.resultCurr.cursor < len(e.resultCurr.rows) { return e.resultCurr, nil } @@ -758,6 +788,7 @@ func handleWorkerPanic(ctx context.Context, finished <-chan struct{}, ch chan<- defer close(ch) } if r == nil { + logutil.BgLogger().Debug("worker finish without panic", zap.Any("worker", worker)) return } @@ -820,7 +851,20 @@ func (w *indexMergeProcessWorker) fetchLoopUnion(ctx context.Context, fetchCh <- failpoint.Inject("testIndexMergePanicProcessWorkerUnion", nil) distinctHandles := make(map[int64]*kv.HandleMap) - for task := range fetchCh { + for { + var ok bool + var task *indexMergeTableTask + select { + case <-ctx.Done(): + return + case <-finished: + return + case task, ok = <-fetchCh: + if !ok { + return + } + } + select { case err := <-task.doneCh: // If got error from partialIndexWorker/partialTableWorker, stop processing. 
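
The loop above replaces for task := range fetchCh with a select that also watches ctx.Done() and finished, so the process worker stops instead of blocking on fetchCh once the reader has already gone away. A minimal, self-contained sketch of that receive-with-cancellation pattern (the task type and channel names are placeholders, not the executor's own):

package main

import (
	"context"
	"fmt"
	"time"
)

type task struct{ id int }

// consume drains fetchCh but gives up as soon as the context is cancelled or
// the finished channel is closed, mirroring the select added to fetchLoopUnion.
func consume(ctx context.Context, finished <-chan struct{}, fetchCh <-chan *task) {
	for {
		var (
			t  *task
			ok bool
		)
		select {
		case <-ctx.Done():
			return
		case <-finished:
			return
		case t, ok = <-fetchCh:
			if !ok {
				return // producers closed the channel; nothing left to do
			}
		}
		fmt.Println("processing task", t.id)
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	finished := make(chan struct{})
	fetchCh := make(chan *task)

	go consume(ctx, finished, fetchCh)
	fetchCh <- &task{id: 1} // delivered and processed
	close(finished)         // simulate the executor closing early
	time.Sleep(50 * time.Millisecond)
	// consume has returned even though fetchCh was never closed.
}
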
@@ -856,7 +900,7 @@ func (w *indexMergeProcessWorker) fetchLoopUnion(ctx context.Context, fetchCh <- if len(fhs) == 0 { continue } - task := &indexMergeTableTask{ + task = &indexMergeTableTask{ lookupTableTask: lookupTableTask{ handles: fhs, doneCh: make(chan error, 1), @@ -867,13 +911,27 @@ func (w *indexMergeProcessWorker) fetchLoopUnion(ctx context.Context, fetchCh <- if w.stats != nil { w.stats.IndexMergeProcess += time.Since(start) } + failpoint.Inject("testIndexMergeProcessWorkerUnionHang", func(_ failpoint.Value) { + for i := 0; i < cap(resultCh); i++ { + select { + case resultCh <- &indexMergeTableTask{}: + default: + } + } + }) select { case <-ctx.Done(): return case <-finished: return case workCh <- task: - resultCh <- task + select { + case <-ctx.Done(): + return + case <-finished: + return + case resultCh <- task: + } } } } @@ -972,6 +1030,14 @@ func (w *intersectionProcessWorker) doIntersectionPerPartition(ctx context.Conte zap.Int("parTblIdx", parTblIdx), zap.Int("task.handles", len(task.handles))) } } + failpoint.Inject("testIndexMergeProcessWorkerIntersectionHang", func(_ failpoint.Value) { + for i := 0; i < cap(resultCh); i++ { + select { + case resultCh <- &indexMergeTableTask{}: + default: + } + } + }) for _, task := range tasks { select { case <-ctx.Done(): @@ -979,7 +1045,13 @@ func (w *intersectionProcessWorker) doIntersectionPerPartition(ctx context.Conte case <-finished: return case workCh <- task: - resultCh <- task + select { + case <-ctx.Done(): + return + case <-finished: + return + case resultCh <- task: + } } } } @@ -1038,29 +1110,47 @@ func (w *indexMergeProcessWorker) fetchLoopIntersection(ctx context.Context, fet }, handleWorkerPanic(ctx, finished, resultCh, errCh, partTblIntersectionWorkerType)) workers = append(workers, worker) } -loop: - for task := range fetchCh { + defer func() { + for _, processWorker := range workers { + close(processWorker.workerCh) + } + wg.Wait() + }() + for { + var ok bool + var task *indexMergeTableTask + select { + case <-ctx.Done(): + return + case <-finished: + return + case task, ok = <-fetchCh: + if !ok { + return + } + } + select { case err := <-task.doneCh: // If got error from partialIndexWorker/partialTableWorker, stop processing. if err != nil { syncErr(ctx, finished, resultCh, err) - break loop + return } default: } select { + case <-ctx.Done(): + return + case <-finished: + return case workers[task.parTblIdx%workerCnt].workerCh <- task: case <-errCh: // If got error from intersectionProcessWorker, stop processing. - break loop + return } } - for _, processWorker := range workers { - close(processWorker.workerCh) - } - wg.Wait() } type partialIndexWorker struct { @@ -1209,12 +1299,14 @@ func (w *indexMergeTableScanWorker) pickAndExecTask(ctx context.Context, task ** for { waitStart := time.Now() select { + case <-ctx.Done(): + return + case <-w.finished: + return case *task, ok = <-w.workCh: if !ok { return } - case <-w.finished: - return } // Make sure panic failpoint is after fetch task from workCh. // Otherwise cannot send error to task.doneCh. 
@@ -1235,13 +1327,20 @@ func (w *indexMergeTableScanWorker) pickAndExecTask(ctx context.Context, task ** atomic.AddInt64(&w.stats.TableTaskNum, 1) } failpoint.Inject("testIndexMergePickAndExecTaskPanic", nil) - (*task).doneCh <- err + select { + case <-ctx.Done(): + return + case <-w.finished: + return + case (*task).doneCh <- err: + } } } func (w *indexMergeTableScanWorker) handleTableScanWorkerPanic(ctx context.Context, finished <-chan struct{}, task **indexMergeTableTask, worker string) func(r interface{}) { return func(r interface{}) { if r == nil { + logutil.BgLogger().Debug("worker finish without panic", zap.Any("worker", worker)) return } diff --git a/executor/index_merge_reader_test.go b/executor/index_merge_reader_test.go index be1ff66a163ab..46a1206460074 100644 --- a/executor/index_merge_reader_test.go +++ b/executor/index_merge_reader_test.go @@ -798,6 +798,49 @@ func TestIntersectionMemQuota(t *testing.T) { require.Contains(t, err.Error(), "Out Of Memory Quota!") } +func setupPartitionTableHelper(tk *testkit.TestKit) { + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1(c1 int, c2 bigint, c3 bigint, primary key(c1), key(c2), key(c3));") + insertStr := "insert into t1 values(0, 0, 0)" + for i := 1; i < 1000; i++ { + insertStr += fmt.Sprintf(", (%d, %d, %d)", i, i, i) + } + tk.MustExec(insertStr) + tk.MustExec("analyze table t1;") + tk.MustExec("set tidb_partition_prune_mode = 'dynamic'") +} + +func TestIndexMergeProcessWorkerHang(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + setupPartitionTableHelper(tk) + + var err error + sql := "select /*+ use_index_merge(t1) */ c1 from t1 where c1 < 900 or c2 < 1000;" + res := tk.MustQuery("explain " + sql).Rows() + require.Contains(t, res[1][0], "IndexMerge") + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/testIndexMergeMainReturnEarly", "return()")) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/testIndexMergeProcessWorkerUnionHang", "return(true)")) + err = tk.QueryToErr(sql) + require.Contains(t, err.Error(), "testIndexMergeMainReturnEarly") + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/testIndexMergeMainReturnEarly")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/testIndexMergeProcessWorkerUnionHang")) + + sql = "select /*+ use_index_merge(t1, c2, c3) */ c1 from t1 where c2 < 900 and c3 < 1000;" + res = tk.MustQuery("explain " + sql).Rows() + require.Contains(t, res[1][0], "IndexMerge") + require.Contains(t, res[1][4], "intersection") + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/testIndexMergeMainReturnEarly", "return()")) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/testIndexMergeProcessWorkerIntersectionHang", "return(true)")) + err = tk.QueryToErr(sql) + require.Contains(t, err.Error(), "testIndexMergeMainReturnEarly") + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/testIndexMergeMainReturnEarly")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/testIndexMergeProcessWorkerIntersectionHang")) +} + func TestIndexMergePanic(t *testing.T) { store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) @@ -811,16 +854,7 @@ func TestIndexMergePanic(t *testing.T) { tk.MustExec("select /*+ use_index_merge(t1, primary, c2, c3) */ c1 from t1 where c1 < 100 or c2 < 100") require.NoError(t, 
failpoint.Disable("github.com/pingcap/tidb/executor/testIndexMergeResultChCloseEarly")) - tk.MustExec("use test") - tk.MustExec("drop table if exists t1") - tk.MustExec("create table t1(c1 int, c2 bigint, c3 bigint, primary key(c1), key(c2), key(c3)) partition by hash(c1) partitions 10;") - insertStr := "insert into t1 values(0, 0, 0)" - for i := 1; i < 1000; i++ { - insertStr += fmt.Sprintf(", (%d, %d, %d)", i, i, i) - } - tk.MustExec(insertStr) - tk.MustExec("analyze table t1;") - tk.MustExec("set tidb_partition_prune_mode = 'dynamic'") + setupPartitionTableHelper(tk) minV := 200 maxV := 1000 @@ -881,3 +915,25 @@ func TestIndexMergePanic(t *testing.T) { require.NoError(t, failpoint.Disable(fp)) } } + +func TestIndexMergeCoprGoroutinesLeak(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + setupPartitionTableHelper(tk) + + var err error + sql := "select /*+ use_index_merge(t1) */ c1 from t1 where c1 < 900 or c2 < 1000;" + res := tk.MustQuery("explain " + sql).Rows() + require.Contains(t, res[1][0], "IndexMerge") + + // If got goroutines leak in coprocessor, ci will fail. + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/testIndexMergePartialTableWorkerCoprLeak", `panic("testIndexMergePartialTableWorkerCoprLeak")`)) + err = tk.QueryToErr(sql) + require.Contains(t, err.Error(), "testIndexMergePartialTableWorkerCoprLeak") + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/testIndexMergePartialTableWorkerCoprLeak")) + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/testIndexMergePartialIndexWorkerCoprLeak", `panic("testIndexMergePartialIndexWorkerCoprLeak")`)) + err = tk.QueryToErr(sql) + require.Contains(t, err.Error(), "testIndexMergePartialIndexWorkerCoprLeak") + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/testIndexMergePartialIndexWorkerCoprLeak")) +} diff --git a/executor/insert.go b/executor/insert.go index 83b486d5d1020..325e8df01e367 100644 --- a/executor/insert.go +++ b/executor/insert.go @@ -159,7 +159,7 @@ func prefetchConflictedOldRows(ctx context.Context, txn kv.Transaction, rows []t for _, r := range rows { for _, uk := range r.uniqueKeys { if val, found := values[string(uk.newKey)]; found { - if isTemp, _ := tablecodec.CheckTempIndexKey(uk.newKey); isTemp { + if tablecodec.IsTempIndexKey(uk.newKey) { // If it is a temp index, the value cannot be decoded by DecodeHandleInUniqueIndexValue. // Since this function is an optimization, we can skip prefetching the rows referenced by // temp indexes. diff --git a/executor/point_get.go b/executor/point_get.go index 3e3cddb08d9ba..bb958055dedd8 100644 --- a/executor/point_get.go +++ b/executor/point_get.go @@ -282,18 +282,6 @@ func (e *PointGetExecutor) Next(ctx context.Context, req *chunk.Chunk) error { return nil } - // Change the unique index LOCK into PUT record. 
- if e.lock { - if !e.txn.Valid() { - return kv.ErrInvalidTxn - } - memBuffer := e.txn.GetMemBuffer() - err = memBuffer.Set(e.idxKey, e.handleVal) - if err != nil { - return err - } - } - var iv kv.Handle iv, err = tablecodec.DecodeHandleInUniqueIndexValue(e.handleVal, e.tblInfo.IsCommonHandle) if err != nil { diff --git a/executor/prepared_test.go b/executor/prepared_test.go index 8ee6a0aa867c4..f428d2b9a891c 100644 --- a/executor/prepared_test.go +++ b/executor/prepared_test.go @@ -169,11 +169,9 @@ func TestIssue29850(t *testing.T) { ps = []*util.ProcessInfo{tkProcess} tk.Session().SetSessionManager(&testkit.MockSessionManager{PS: ps}) tk.MustQuery(fmt.Sprintf("explain for connection %d", tkProcess.ID)).Check(testkit.Rows( // cannot use PointGet since it contains a range condition - `Selection_7 1.00 root ge(test.t.a, 1), le(test.t.a, 1)`, - `└─TableReader_6 1.00 root data:TableRangeScan_5`, - ` └─TableRangeScan_5 1.00 cop[tikv] table:t range:[1,1], keep order:false, stats:pseudo`)) + `Point_Get_5 1.00 root table:t handle:1`)) tk.MustQuery(`execute stmt using @a1, @a2`).Check(testkit.Rows("1", "2")) - tk.MustQuery(`select @@last_plan_from_cache`).Check(testkit.Rows("1")) + tk.MustQuery(`select @@last_plan_from_cache`).Check(testkit.Rows("0")) tk.MustExec(`prepare stmt from 'select * from t where a=? or a=?'`) tk.MustQuery(`execute stmt using @a1, @a1`).Check(testkit.Rows("1")) @@ -181,9 +179,7 @@ func TestIssue29850(t *testing.T) { ps = []*util.ProcessInfo{tkProcess} tk.Session().SetSessionManager(&testkit.MockSessionManager{PS: ps}) tk.MustQuery(fmt.Sprintf("explain for connection %d", tkProcess.ID)).Check(testkit.Rows( // cannot use PointGet since it contains a or condition - `Selection_7 1.00 root or(eq(test.t.a, 1), eq(test.t.a, 1))`, - `└─TableReader_6 1.00 root data:TableRangeScan_5`, - ` └─TableRangeScan_5 1.00 cop[tikv] table:t range:[1,1], keep order:false, stats:pseudo`)) + `Point_Get_5 1.00 root table:t handle:1`)) tk.MustQuery(`execute stmt using @a1, @a2`).Check(testkit.Rows("1", "2")) } diff --git a/executor/revoke.go b/executor/revoke.go index 337e387c5b28f..9063206ffd52a 100644 --- a/executor/revoke.go +++ b/executor/revoke.go @@ -180,6 +180,9 @@ func (e *RevokeExec) revokeOneUser(internalSession sessionctx.Context, user, hos } func (e *RevokeExec) revokePriv(internalSession sessionctx.Context, priv *ast.PrivElem, user, host string) error { + if priv.Priv == mysql.UsagePriv { + return nil + } switch e.Level.Level { case ast.GrantLevelGlobal: return e.revokeGlobalPriv(internalSession, priv, user, host) diff --git a/executor/revoke_test.go b/executor/revoke_test.go index 635fa18552df5..fcc53b5d291fe 100644 --- a/executor/revoke_test.go +++ b/executor/revoke_test.go @@ -271,3 +271,18 @@ func TestRevokeOnNonExistTable(t *testing.T) { tk.MustExec("DROP TABLE t1;") tk.MustExec("REVOKE ALTER ON d1.t1 FROM issue28533;") } + +// Check https://github.com/pingcap/tidb/issues/41773. 
+func TestIssue41773(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table if not exists xx (id int)") + tk.MustExec("CREATE USER 't1234'@'%' IDENTIFIED BY 'sNGNQo12fEHe0n3vU';") + tk.MustExec("GRANT USAGE ON * TO 't1234'@'%';") + tk.MustExec("GRANT USAGE ON test.* TO 't1234'@'%';") + tk.MustExec("GRANT USAGE ON test.xx TO 't1234'@'%';") + tk.MustExec("REVOKE USAGE ON * FROM 't1234'@'%';") + tk.MustExec("REVOKE USAGE ON test.* FROM 't1234'@'%';") + tk.MustExec("REVOKE USAGE ON test.xx FROM 't1234'@'%';") +} diff --git a/expression/builtin_compare.go b/expression/builtin_compare.go index bed48c0e59096..de0ddd8d61c37 100644 --- a/expression/builtin_compare.go +++ b/expression/builtin_compare.go @@ -1565,19 +1565,29 @@ func allowCmpArgsRefining4PlanCache(ctx sessionctx.Context, args []Expression) ( return true // plan-cache disabled or no parameter in these args } - // For these 2 cases below which may affect the index selection a lot, skip plan-cache, - // and for all other cases, skip the refining. - // 1. int-expr string-const - // 2. int-expr float/double/decimal-const + // For these 2 cases below, we skip the refining: + // 1. year-expr const + // 2. int-expr string/float/double/decimal-const for conIdx := 0; conIdx < 2; conIdx++ { - if args[1-conIdx].GetType().EvalType() != types.ETInt { - continue // not a int-expr - } if _, isCon := args[conIdx].(*Constant); !isCon { continue // not a constant } + + // case 1: year-expr const + // refine `year < 12` to `year < 2012` to guarantee the correctness. + // see https://github.com/pingcap/tidb/issues/41626 for more details. + exprType := args[1-conIdx].GetType() + if exprType.GetType() == mysql.TypeYear { + reason := errors.Errorf("skip plan-cache: '%v' may be converted to INT", args[conIdx].String()) + ctx.GetSessionVars().StmtCtx.SetSkipPlanCache(reason) + return true + } + + // case 2: int-expr string/float/double/decimal-const + // refine `int_key < 1.1` to `int_key < 2` to generate RangeScan instead of FullScan. 
conType := args[conIdx].GetType().EvalType() - if conType == types.ETString || conType == types.ETReal || conType == types.ETDecimal { + if exprType.EvalType() == types.ETInt && + (conType == types.ETString || conType == types.ETReal || conType == types.ETDecimal) { reason := errors.Errorf("skip plan-cache: '%v' may be converted to INT", args[conIdx].String()) ctx.GetSessionVars().StmtCtx.SetSkipPlanCache(reason) return true diff --git a/expression/integration_test.go b/expression/integration_test.go index 55c8f389a5df3..68507621f1e26 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -3743,18 +3743,23 @@ func TestShardIndexOnTiFlash(t *testing.T) { } } } + tk.MustExec("set @@session.tidb_isolation_read_engines = 'tiflash'") tk.MustExec("set @@session.tidb_enforce_mpp = 1") rows := tk.MustQuery("explain select max(b) from t").Rows() for _, row := range rows { line := fmt.Sprintf("%v", row) - require.NotContains(t, line, "tiflash") + if strings.Contains(line, "TableFullScan") { + require.Contains(t, line, "tiflash") + } } tk.MustExec("set @@session.tidb_enforce_mpp = 0") tk.MustExec("set @@session.tidb_allow_mpp = 0") rows = tk.MustQuery("explain select max(b) from t").Rows() for _, row := range rows { line := fmt.Sprintf("%v", row) - require.NotContains(t, line, "tiflash") + if strings.Contains(line, "TableFullScan") { + require.NotContains(t, line, "mpp[tiflash]") + } } } diff --git a/parser/mysql/type.go b/parser/mysql/type.go index f79be8ab30d96..8a2531d870d3e 100644 --- a/parser/mysql/type.go +++ b/parser/mysql/type.go @@ -74,6 +74,7 @@ const ( PreventNullInsertFlag uint = 1 << 20 /* Prevent this Field from inserting NULL values */ EnumSetAsIntFlag uint = 1 << 21 /* Internal: Used for inferring enum eval type. */ DropColumnIndexFlag uint = 1 << 22 /* Internal: Used for indicate the column is being dropped with index */ + GeneratedColumnFlag uint = 1 << 23 /* Internal: TiFlash will check this flag and add a placeholder for this column */ ) // TypeInt24 bounds. 
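
GeneratedColumnFlag follows the same convention as the other column flags: a dedicated bit in a uint that callers set with AddFlag and test with the Has*Flag helpers (convertToTableScan below sets it for non-stored generated columns before handing the scan to TiFlash). A tiny sketch of the bit arithmetic, using generic helper names rather than the parser package's own:

package main

import "fmt"

const (
	unsignedFlag        uint = 1 << 5  // existing flag, shown for contrast
	generatedColumnFlag uint = 1 << 23 // the new internal flag added above
)

func addFlag(flags, f uint) uint { return flags | f }
func hasFlag(flags, f uint) bool { return flags&f > 0 }
func delFlag(flags, f uint) uint { return flags &^ f }

func main() {
	var flags uint
	flags = addFlag(flags, generatedColumnFlag)
	fmt.Println(hasFlag(flags, generatedColumnFlag)) // true
	fmt.Println(hasFlag(flags, unsignedFlag))        // false
	flags = delFlag(flags, generatedColumnFlag)
	fmt.Println(hasFlag(flags, generatedColumnFlag)) // false
}
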
diff --git a/planner/core/exhaust_physical_plans.go b/planner/core/exhaust_physical_plans.go index 768a8c20fc0b5..3040563a564e0 100644 --- a/planner/core/exhaust_physical_plans.go +++ b/planner/core/exhaust_physical_plans.go @@ -2255,7 +2255,7 @@ func pushLimitOrTopNForcibly(p LogicalPlan) bool { } func (lt *LogicalTopN) getPhysTopN(_ *property.PhysicalProperty) []PhysicalPlan { - allTaskTypes := []property.TaskType{property.CopSingleReadTaskType, property.CopDoubleReadTaskType} + allTaskTypes := []property.TaskType{property.CopSingleReadTaskType, property.CopMultiReadTaskType} if !pushLimitOrTopNForcibly(lt) { allTaskTypes = append(allTaskTypes, property.RootTaskType) } @@ -2281,7 +2281,7 @@ func (lt *LogicalTopN) getPhysLimits(_ *property.PhysicalProperty) []PhysicalPla return nil } - allTaskTypes := []property.TaskType{property.CopSingleReadTaskType, property.CopDoubleReadTaskType} + allTaskTypes := []property.TaskType{property.CopSingleReadTaskType, property.CopMultiReadTaskType} if !pushLimitOrTopNForcibly(lt) { allTaskTypes = append(allTaskTypes, property.RootTaskType) } @@ -2606,7 +2606,7 @@ func (la *LogicalAggregation) getEnforcedStreamAggs(prop *property.PhysicalPrope if !prop.IsPrefix(childProp) { return enforcedAggs } - taskTypes := []property.TaskType{property.CopSingleReadTaskType, property.CopDoubleReadTaskType} + taskTypes := []property.TaskType{property.CopSingleReadTaskType, property.CopMultiReadTaskType} if la.HasDistinct() { // TODO: remove AllowDistinctAggPushDown after the cost estimation of distinct pushdown is implemented. // If AllowDistinctAggPushDown is set to true, we should not consider RootTask. @@ -2839,7 +2839,7 @@ func (la *LogicalAggregation) getHashAggs(prop *property.PhysicalProperty) []Phy return nil } hashAggs := make([]PhysicalPlan, 0, len(prop.GetAllPossibleChildTaskTypes())) - taskTypes := []property.TaskType{property.CopSingleReadTaskType, property.CopDoubleReadTaskType} + taskTypes := []property.TaskType{property.CopSingleReadTaskType, property.CopMultiReadTaskType} canPushDownToTiFlash := la.canPushToCop(kv.TiFlash) canPushDownToMPP := canPushDownToTiFlash && la.ctx.GetSessionVars().IsMPPAllowed() && la.checkCanPushDownToMPP() if la.HasDistinct() { @@ -2962,7 +2962,7 @@ func (p *LogicalLimit) exhaustPhysicalPlans(prop *property.PhysicalProperty) ([] return nil, true, nil } - allTaskTypes := []property.TaskType{property.CopSingleReadTaskType, property.CopDoubleReadTaskType} + allTaskTypes := []property.TaskType{property.CopSingleReadTaskType, property.CopMultiReadTaskType} if !pushLimitOrTopNForcibly(p) { allTaskTypes = append(allTaskTypes, property.RootTaskType) } diff --git a/planner/core/find_best_task.go b/planner/core/find_best_task.go index 95d7698fbfe72..6a3918ce479fb 100644 --- a/planner/core/find_best_task.go +++ b/planner/core/find_best_task.go @@ -327,6 +327,19 @@ func getTaskPlanCost(t task, op *physicalOptimizeOp) (float64, bool, error) { default: return 0, false, errors.New("unknown task type") } + if t.plan() == nil { + // It's a very special case for index merge case. 
+ cost := 0.0 + copTsk := t.(*copTask) + for _, partialScan := range copTsk.idxMergePartPlans { + partialCost, err := getPlanCost(partialScan, taskType, NewDefaultPlanCostOption().WithOptimizeTracer(op)) + if err != nil { + return 0, false, err + } + cost += partialCost + } + return cost, false, nil + } cost, err := getPlanCost(t.plan(), taskType, NewDefaultPlanCostOption().WithOptimizeTracer(op)) return cost, false, err } @@ -938,10 +951,6 @@ func (ds *DataSource) findBestTask(prop *property.PhysicalProperty, planCounter canConvertPointGet := len(path.Ranges) > 0 && path.StoreType == kv.TiKV && ds.isPointGetConvertableSchema() - if canConvertPointGet && expression.MaybeOverOptimized4PlanCache(ds.ctx, path.AccessConds) { - canConvertPointGet = ds.canConvertToPointGetForPlanCache(path) - } - if canConvertPointGet && !path.IsIntHandlePath { // We simply do not build [batch] point get for prefix indexes. This can be optimized. canConvertPointGet = path.Index.Unique && !path.Index.HasPrefixIndex() @@ -997,6 +1006,13 @@ func (ds *DataSource) findBestTask(prop *property.PhysicalProperty, planCounter } else { pointGetTask = ds.convertToBatchPointGet(prop, candidate, hashPartColName, opt) } + + // Batch/PointGet plans may be over-optimized, like `a>=1(?) and a<=1(?)` --> `a=1` --> PointGet(a=1). + // For safety, prevent these plans from the plan cache here. + if !pointGetTask.invalid() && expression.MaybeOverOptimized4PlanCache(ds.ctx, candidate.path.AccessConds) && !ds.isSafePointGetPlan4PlanCache(candidate.path) { + ds.ctx.GetSessionVars().StmtCtx.SetSkipPlanCache(errors.New("Batch/PointGet plans may be over-optimized")) + } + appendCandidate(ds, pointGetTask, prop, opt) if !pointGetTask.invalid() { cntPlan++ @@ -1076,12 +1092,11 @@ func (ds *DataSource) findBestTask(prop *property.PhysicalProperty, planCounter return } -func (ds *DataSource) canConvertToPointGetForPlanCache(path *util.AccessPath) bool { +func (ds *DataSource) isSafePointGetPlan4PlanCache(path *util.AccessPath) bool { // PointGet might contain some over-optimized assumptions, like `a>=1 and a<=1` --> `a=1`, but // these assumptions may be broken after parameters change. - // So for safety, we narrow down the scope and just generate PointGet in some particular and simple scenarios. 
- // scenario 1: each column corresponds to a single EQ, `a=1 and b=2 and c=3` --> `[1, 2, 3]` + // safe scenario 1: each column corresponds to a single EQ, `a=1 and b=2 and c=3` --> `[1, 2, 3]` if len(path.Ranges) > 0 && path.Ranges[0].Width() == len(path.AccessConds) { for _, accessCond := range path.AccessConds { f, ok := accessCond.(*expression.ScalarFunction) @@ -1098,13 +1113,16 @@ func (ds *DataSource) canConvertToPointGetForPlanCache(path *util.AccessPath) bo } func (ds *DataSource) convertToIndexMergeScan(prop *property.PhysicalProperty, candidate *candidatePath, _ *physicalOptimizeOp) (task task, err error) { - if prop.TaskTp != property.RootTaskType || !prop.IsSortItemEmpty() { + if prop.IsFlashProp() || prop.TaskTp == property.CopSingleReadTaskType || !prop.IsSortItemEmpty() { + return invalidTask, nil + } + if prop.TaskTp == property.CopMultiReadTaskType && candidate.path.IndexMergeIsIntersection { return invalidTask, nil } path := candidate.path scans := make([]PhysicalPlan, 0, len(path.PartialIndexPaths)) cop := &copTask{ - indexPlanFinished: true, + indexPlanFinished: false, tblColHists: ds.TblColHists, } cop.partitionInfo = PartitionInfo{ @@ -1128,7 +1146,7 @@ func (ds *DataSource) convertToIndexMergeScan(prop *property.PhysicalProperty, c } ts, remainingFilters, err := ds.buildIndexMergeTableScan(prop, path.TableFilters, totalRowCount) if err != nil { - return nil, err + return invalidTask, err } cop.tablePlan = ts cop.idxMergePartPlans = scans @@ -1136,8 +1154,17 @@ func (ds *DataSource) convertToIndexMergeScan(prop *property.PhysicalProperty, c if remainingFilters != nil { cop.rootTaskConds = remainingFilters } - task = cop.convertToRootTask(ds.ctx) - ds.addSelection4PlanCache(task.(*rootTask), ds.tableStats.ScaleByExpectCnt(totalRowCount), prop) + _, pureTableScan := ts.(*PhysicalTableScan) + if prop.TaskTp != property.RootTaskType && (len(remainingFilters) > 0 || !pureTableScan) { + return invalidTask, nil + } + if prop.TaskTp == property.RootTaskType { + cop.indexPlanFinished = true + task = cop.convertToRootTask(ds.ctx) + ds.addSelection4PlanCache(task.(*rootTask), ds.tableStats.ScaleByExpectCnt(totalRowCount), prop) + } else { + task = cop + } return task, nil } @@ -1420,7 +1447,7 @@ func (ds *DataSource) convertToIndexScan(prop *property.PhysicalProperty, if prop.TaskTp == property.CopSingleReadTaskType { return invalidTask, nil } - } else if prop.TaskTp == property.CopDoubleReadTaskType { + } else if prop.TaskTp == property.CopMultiReadTaskType { // If it's parent requires double read task, return max cost. return invalidTask, nil } @@ -1960,7 +1987,7 @@ func (ds *DataSource) isPointGetPath(path *util.AccessPath) bool { // convertToTableScan converts the DataSource to table scan. func (ds *DataSource) convertToTableScan(prop *property.PhysicalProperty, candidate *candidatePath, _ *physicalOptimizeOp) (task task, err error) { // It will be handled in convertToIndexScan. 
- if prop.TaskTp == property.CopDoubleReadTaskType { + if prop.TaskTp == property.CopMultiReadTaskType { return invalidTask, nil } if !prop.IsSortItemEmpty() && !candidate.isMatchProp { @@ -1972,15 +1999,9 @@ func (ds *DataSource) convertToTableScan(prop *property.PhysicalProperty, candid return invalidTask, nil } if ts.StoreType == kv.TiFlash { - for _, col := range ts.schema.Columns { - // In theory, TiFlash does not support virtual expr, but in non-mpp mode, if the cop request only contain table scan, then - // TiDB will fill the virtual column after decoding the cop response(executor.FillVirtualColumnValue), that is to say, the virtual - // columns in Cop request is just a placeholder, so TiFlash can support virtual column in cop request mode. However, virtual column - // with TiDBShard is special, it can be added using create index statement, TiFlash's ddl does not handle create index statement, so - // there is a chance that the TiDBShard's virtual column is not seen by TiFlash, in this case, TiFlash will throw column not found error - if ds.containExprPrefixUk && expression.GcColumnExprIsTidbShard(col.VirtualExpr) { - ds.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because column `" + col.OrigName + "` is a virtual column which is not supported now.") - return invalidTask, nil + for _, col := range ts.Columns { + if col.IsGenerated() && !col.GeneratedStored { + col.AddFlag(mysql.GeneratedColumnFlag) } } } @@ -2047,7 +2068,7 @@ func (ds *DataSource) convertToTableScan(prop *property.PhysicalProperty, candid func (ds *DataSource) convertToSampleTable(prop *property.PhysicalProperty, candidate *candidatePath, _ *physicalOptimizeOp) (task task, err error) { - if prop.TaskTp == property.CopDoubleReadTaskType { + if prop.TaskTp == property.CopMultiReadTaskType { return invalidTask, nil } if !prop.IsSortItemEmpty() && !candidate.isMatchProp { @@ -2074,7 +2095,7 @@ func (ds *DataSource) convertToPointGet(prop *property.PhysicalProperty, candida if !prop.IsSortItemEmpty() && !candidate.isMatchProp { return invalidTask } - if prop.TaskTp == property.CopDoubleReadTaskType && candidate.path.IsSingleScan || + if prop.TaskTp == property.CopMultiReadTaskType && candidate.path.IsSingleScan || prop.TaskTp == property.CopSingleReadTaskType && !candidate.path.IsSingleScan { return invalidTask } @@ -2152,7 +2173,7 @@ func (ds *DataSource) convertToBatchPointGet(prop *property.PhysicalProperty, if !prop.IsSortItemEmpty() && !candidate.isMatchProp { return invalidTask } - if prop.TaskTp == property.CopDoubleReadTaskType && candidate.path.IsSingleScan || + if prop.TaskTp == property.CopMultiReadTaskType && candidate.path.IsSingleScan || prop.TaskTp == property.CopSingleReadTaskType && !candidate.path.IsSingleScan { return invalidTask } diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index a355d0782cbd9..ae06750fe9624 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -5702,14 +5702,18 @@ func TestIssue29221(t *testing.T) { tk.MustQuery("explain format = 'brief' select * from t where a = 1 or b = 1;").Check(testkit.Rows( "Limit 3.00 root offset:0, count:3", "└─IndexMerge 3.00 root type: union", - " ├─IndexRangeScan(Build) 1.50 cop[tikv] table:t, index:idx_a(a) range:[1,1], keep order:false, stats:pseudo", - " ├─IndexRangeScan(Build) 1.50 cop[tikv] table:t, index:idx_b(b) range:[1,1], keep order:false, stats:pseudo", + " ├─Limit(Build) 1.50 cop[tikv] offset:0, count:3", + " │ └─IndexRangeScan 1.50 cop[tikv] 
table:t, index:idx_a(a) range:[1,1], keep order:false, stats:pseudo", + " ├─Limit(Build) 1.50 cop[tikv] offset:0, count:3", + " │ └─IndexRangeScan 1.50 cop[tikv] table:t, index:idx_b(b) range:[1,1], keep order:false, stats:pseudo", " └─TableRowIDScan(Probe) 3.00 cop[tikv] table:t keep order:false, stats:pseudo")) tk.MustQuery("explain format = 'brief' select /*+ use_index_merge(t) */ * from t where a = 1 or b = 1;").Check(testkit.Rows( "Limit 3.00 root offset:0, count:3", "└─IndexMerge 3.00 root type: union", - " ├─IndexRangeScan(Build) 1.50 cop[tikv] table:t, index:idx_a(a) range:[1,1], keep order:false, stats:pseudo", - " ├─IndexRangeScan(Build) 1.50 cop[tikv] table:t, index:idx_b(b) range:[1,1], keep order:false, stats:pseudo", + " ├─Limit(Build) 1.50 cop[tikv] offset:0, count:3", + " │ └─IndexRangeScan 1.50 cop[tikv] table:t, index:idx_a(a) range:[1,1], keep order:false, stats:pseudo", + " ├─Limit(Build) 1.50 cop[tikv] offset:0, count:3", + " │ └─IndexRangeScan 1.50 cop[tikv] table:t, index:idx_b(b) range:[1,1], keep order:false, stats:pseudo", " └─TableRowIDScan(Probe) 3.00 cop[tikv] table:t keep order:false, stats:pseudo")) tk.MustExec("set @@session.sql_select_limit=18446744073709551615;") tk.MustQuery("explain format = 'brief' select * from t where a = 1 or b = 1;").Check(testkit.Rows( @@ -5720,8 +5724,10 @@ func TestIssue29221(t *testing.T) { tk.MustQuery("explain format = 'brief' select * from t where a = 1 or b = 1 limit 3;").Check(testkit.Rows( "Limit 3.00 root offset:0, count:3", "└─IndexMerge 3.00 root type: union", - " ├─IndexRangeScan(Build) 1.50 cop[tikv] table:t, index:idx_a(a) range:[1,1], keep order:false, stats:pseudo", - " ├─IndexRangeScan(Build) 1.50 cop[tikv] table:t, index:idx_b(b) range:[1,1], keep order:false, stats:pseudo", + " ├─Limit(Build) 1.50 cop[tikv] offset:0, count:3", + " │ └─IndexRangeScan 1.50 cop[tikv] table:t, index:idx_a(a) range:[1,1], keep order:false, stats:pseudo", + " ├─Limit(Build) 1.50 cop[tikv] offset:0, count:3", + " │ └─IndexRangeScan 1.50 cop[tikv] table:t, index:idx_b(b) range:[1,1], keep order:false, stats:pseudo", " └─TableRowIDScan(Probe) 3.00 cop[tikv] table:t keep order:false, stats:pseudo")) tk.MustQuery("explain format = 'brief' select /*+ use_index_merge(t) */ * from t where a = 1 or b = 1;").Check(testkit.Rows( "IndexMerge 19.99 root type: union", @@ -5731,8 +5737,10 @@ func TestIssue29221(t *testing.T) { tk.MustQuery("explain format = 'brief' select /*+ use_index_merge(t) */ * from t where a = 1 or b = 1 limit 3;").Check(testkit.Rows( "Limit 3.00 root offset:0, count:3", "└─IndexMerge 3.00 root type: union", - " ├─IndexRangeScan(Build) 1.50 cop[tikv] table:t, index:idx_a(a) range:[1,1], keep order:false, stats:pseudo", - " ├─IndexRangeScan(Build) 1.50 cop[tikv] table:t, index:idx_b(b) range:[1,1], keep order:false, stats:pseudo", + " ├─Limit(Build) 1.50 cop[tikv] offset:0, count:3", + " │ └─IndexRangeScan 1.50 cop[tikv] table:t, index:idx_a(a) range:[1,1], keep order:false, stats:pseudo", + " ├─Limit(Build) 1.50 cop[tikv] offset:0, count:3", + " │ └─IndexRangeScan 1.50 cop[tikv] table:t, index:idx_b(b) range:[1,1], keep order:false, stats:pseudo", " └─TableRowIDScan(Probe) 3.00 cop[tikv] table:t keep order:false, stats:pseudo")) } diff --git a/planner/core/partition_prune.go b/planner/core/partition_prune.go index 3ab266340829d..2dec9c62d19e2 100644 --- a/planner/core/partition_prune.go +++ b/planner/core/partition_prune.go @@ -40,7 +40,7 @@ func PartitionPruning(ctx sessionctx.Context, tbl table.PartitionedTable, conds ret := 
s.convertToIntSlice(rangeOr, pi, partitionNames) return ret, nil case model.PartitionTypeList: - return s.pruneListPartition(ctx, tbl, partitionNames, conds) + return s.pruneListPartition(ctx, tbl, partitionNames, conds, columns) } return []int{FullRange}, nil } diff --git a/planner/core/partition_pruner_test.go b/planner/core/partition_pruner_test.go index 30340948f5d65..97b8595e5422d 100644 --- a/planner/core/partition_pruner_test.go +++ b/planner/core/partition_pruner_test.go @@ -688,6 +688,42 @@ func TestRangePartitionPredicatePruner(t *testing.T) { } } +func TestIssue42135(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec(`create database issue42135`) + tk.MustExec(`use issue42135`) + tk.MustExec("CREATE TABLE `tx1` (`ID` varchar(13), `a` varchar(13), `b` varchar(4000), `ltype` int(5) NOT NULL)") + tk.MustExec("CREATE TABLE `tx2` (`ID` varchar(13), `rid` varchar(12), `a` varchar(9), `b` varchar(8), `c` longtext, `d` varchar(12), `ltype` int(5) NOT NULL) PARTITION BY LIST (`ltype`) (PARTITION `p1` VALUES IN (501), PARTITION `p2` VALUES IN (502))") + tk.MustExec("insert into tx1 values(1,1,1,501)") + tk.MustExec("insert into tx2 values(1,1,1,1,1,1,501)") + tk.MustExec(`analyze table tx1`) + tk.MustExec(`analyze table tx2`) + tk.MustQuery(`select * from tx1 inner join tx2 on tx1.ID=tx2.ID and tx1.ltype=tx2.ltype where tx2.rid='1'`).Check(testkit.Rows("1 1 1 501 1 1 1 1 1 1 501")) + tk.MustQuery(`explain format='brief' select * from tx1 inner join tx2 on tx1.ID=tx2.ID and tx1.ltype=tx2.ltype where tx2.rid='1'`).Check(testkit.Rows(""+ + "HashJoin 1.00 root inner join, equal:[eq(issue42135.tx1.id, issue42135.tx2.id) eq(issue42135.tx1.ltype, issue42135.tx2.ltype)]", + `├─TableReader(Build) 1.00 root data:Selection`, + `│ └─Selection 1.00 cop[tikv] not(isnull(issue42135.tx1.id))`, + `│ └─TableFullScan 1.00 cop[tikv] table:tx1 keep order:false`, + `└─TableReader(Probe) 1.00 root partition:all data:Selection`, + ` └─Selection 1.00 cop[tikv] eq(issue42135.tx2.rid, "1"), not(isnull(issue42135.tx2.id))`, + ` └─TableFullScan 1.00 cop[tikv] table:tx2 keep order:false`)) + + tk.MustExec(`drop table tx2`) + tk.MustExec("CREATE TABLE `tx2` (`ID` varchar(13), `rid` varchar(12), `a` varchar(9), `b` varchar(8), `c` longtext, `d` varchar(12), `ltype` int(5) NOT NULL) PARTITION BY LIST COLUMNS (`ltype`,d) (PARTITION `p1` VALUES IN ((501,1)), PARTITION `p2` VALUES IN ((502,1)))") + tk.MustExec("insert into tx2 values(1,1,1,1,1,1,501)") + tk.MustExec(`analyze table tx2`) + tk.MustQuery(`select * from tx1 inner join tx2 on tx1.ID=tx2.ID and tx1.ltype=tx2.ltype where tx2.rid='1'`).Check(testkit.Rows("1 1 1 501 1 1 1 1 1 1 501")) + tk.MustQuery(`explain format='brief' select * from tx1 inner join tx2 on tx1.ID=tx2.ID and tx1.ltype=tx2.ltype where tx2.rid='1'`).Check(testkit.Rows(""+ + "HashJoin 1.00 root inner join, equal:[eq(issue42135.tx1.id, issue42135.tx2.id) eq(issue42135.tx1.ltype, issue42135.tx2.ltype)]", + "├─TableReader(Build) 1.00 root data:Selection", + "│ └─Selection 1.00 cop[tikv] not(isnull(issue42135.tx1.id))", + "│ └─TableFullScan 1.00 cop[tikv] table:tx1 keep order:false", + "└─TableReader(Probe) 1.00 root partition:all data:Selection", + " └─Selection 1.00 cop[tikv] eq(issue42135.tx2.rid, \"1\"), not(isnull(issue42135.tx2.id))", + " └─TableFullScan 1.00 cop[tikv] table:tx2 keep order:false")) +} + func TestHashPartitionPruning(t *testing.T) { store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) diff --git 
a/planner/core/physical_plan_test.go b/planner/core/physical_plan_test.go index 7e5b13b588e16..20ea415ceb4c3 100644 --- a/planner/core/physical_plan_test.go +++ b/planner/core/physical_plan_test.go @@ -2533,3 +2533,30 @@ func TestCountStarForTiFlash(t *testing.T) { tk.MustQuery("explain format = 'brief' " + ts).Check(testkit.Rows(output[i].Plan...)) } } + +func TestIndexMergeOrderPushDown(t *testing.T) { + var ( + input []string + output []struct { + SQL string + Plan []string + Warning []string + } + ) + planSuiteData := core.GetPlanSuiteData() + planSuiteData.LoadTestCases(t, &input, &output) + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + + tk.MustExec("use test") + tk.MustExec("set tidb_cost_model_version=1") + tk.MustExec("create table t (a int, b int, c int, index idx(a, c), index idx2(b, c))") + + for i, ts := range input { + testdata.OnRecord(func() { + output[i].SQL = ts + output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery("explain format = 'brief' " + ts).Rows()) + }) + tk.MustQuery("explain format = 'brief' " + ts).Check(testkit.Rows(output[i].Plan...)) + } +} diff --git a/planner/core/plan_cache.go b/planner/core/plan_cache.go index a2002a6f18010..b6e5c3824781b 100644 --- a/planner/core/plan_cache.go +++ b/planner/core/plan_cache.go @@ -725,6 +725,11 @@ func containShuffleOperator(p PhysicalPlan) bool { if _, isShuffleRecv := p.(*PhysicalShuffleReceiverStub); isShuffleRecv { return true } + for _, child := range p.Children() { + if containShuffleOperator(child) { + return true + } + } return false } diff --git a/planner/core/plan_cache_test.go b/planner/core/plan_cache_test.go index b63e6f0f2f3d0..2aa0a6c021387 100644 --- a/planner/core/plan_cache_test.go +++ b/planner/core/plan_cache_test.go @@ -112,6 +112,18 @@ func TestGeneralPlanCacheBasically(t *testing.T) { } } +func TestIssue41626(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec(`use test`) + tk.MustExec(`create table t (a year)`) + tk.MustExec(`insert into t values (2000)`) + tk.MustExec(`prepare st from 'select * from t where a=? and b<=?'") + tk.MustExec("set @a=1, @b=1") + tk.MustExec("execute st using @a, @b") + tkProcess = tk.Session().ShowProcess() + ps = []*util.ProcessInfo{tkProcess} + tk.Session().SetSessionManager(&testkit.MockSessionManager{PS: ps}) + rows = tk.MustQuery(fmt.Sprintf("explain for connection %d", tkProcess.ID)).Rows() + require.Equal(t, rows[0][0], "Point_Get_5") // use Point_Get_5 + tk.MustExec("execute st using @a, @b") + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) // cannot hit + + // safe PointGet + tk.MustExec("prepare st from 'select * from t where a=1 and b=? 
and c Selection + require.Contains(t, rows[1][0], "Point_Get") + tk.MustExec("execute st using @a, @b") + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) // can hit +} + +func TestIssue41828(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec(`CREATE TABLE IDT_MULTI15840STROBJSTROBJ ( + COL1 enum('aa', 'zzz') DEFAULT NULL, + COL2 smallint(6) DEFAULT NULL, + COL3 date DEFAULT NULL, + KEY U_M_COL4 (COL1,COL2), + KEY U_M_COL5 (COL3,COL2))`) + + tk.MustExec(`INSERT INTO IDT_MULTI15840STROBJSTROBJ VALUES ('zzz',1047,'6115-06-05'),('zzz',-23221,'4250-09-03'),('zzz',27138,'1568-07-30'),('zzz',-30903,'6753-08-21'),('zzz',-26875,'6117-10-10')`) + tk.MustExec(`prepare stmt from 'select * from IDT_MULTI15840STROBJSTROBJ where col3 <=> ? or col1 in (?, ?, ?) and col2 not between ? and ?'`) + tk.MustExec(`set @a="0051-12-23", @b="none", @c="none", @d="none", @e=-32757, @f=-32757`) + tk.MustQuery(`execute stmt using @a,@b,@c,@d,@e,@f`).Check(testkit.Rows()) + tk.MustExec(`set @a="9795-01-10", @b="aa", @c="aa", @d="aa", @e=31928, @f=31928`) + tk.MustQuery(`execute stmt using @a,@b,@c,@d,@e,@f`).Check(testkit.Rows()) +} + +func TestIssue42150(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec(`use test`) + tk.MustExec("drop table if exists t1, t2") + tk.MustExec("CREATE TABLE `t1` (`c_int` int(11) NOT NULL, `c_str` varchar(40) CHARACTER SET utf8 COLLATE utf8_bin DEFAULT NULL, `c_datetime` datetime DEFAULT NULL, `c_timestamp` timestamp NULL DEFAULT NULL, `c_double` double DEFAULT NULL, `c_decimal` decimal(12,6) DEFAULT NULL, `c_enum` enum('blue','green','red','yellow','white','orange','purple') NOT NULL, PRIMARY KEY (`c_int`,`c_enum`) /*T![clustered_index] CLUSTERED */, KEY `c_decimal` (`c_decimal`), UNIQUE KEY `c_datetime` (`c_datetime`), UNIQUE KEY `c_timestamp` (`c_timestamp`)) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;") + tk.MustExec("create table t (a int, b int, primary key(a), key(b))") + + tk.MustExec("prepare stmt from 'select c_enum from t1'") + tk.MustExec("execute stmt;") + tk.MustExec("execute stmt;") + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + + tk.MustExec("prepare st from 'select a from t use index(b)'") + tk.MustExec("execute st") + tk.MustExec("execute st") + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) +} diff --git a/planner/core/plan_cost_ver1.go b/planner/core/plan_cost_ver1.go index f7459c70fb01a..ecbdc0e654fb0 100644 --- a/planner/core/plan_cost_ver1.go +++ b/planner/core/plan_cost_ver1.go @@ -77,7 +77,7 @@ func (p *PhysicalSelection) getPlanCostVer1(taskType property.TaskType, option * switch taskType { case property.RootTaskType, property.MppTaskType: cpuFactor = p.ctx.GetSessionVars().GetCPUFactor() - case property.CopSingleReadTaskType, property.CopDoubleReadTaskType: + case property.CopSingleReadTaskType, property.CopMultiReadTaskType: cpuFactor = p.ctx.GetSessionVars().GetCopCPUFactor() default: return 0, errors.Errorf("unknown task type %v", taskType) @@ -181,7 +181,7 @@ func (p *PhysicalIndexLookUpReader) getPlanCostVer1(taskType property.TaskType, p.planCost = 0 // child's cost for _, child := range []PhysicalPlan{p.indexPlan, p.tablePlan} { - childCost, err := child.getPlanCostVer1(property.CopDoubleReadTaskType, option) + childCost, err := child.getPlanCostVer1(property.CopMultiReadTaskType, option) if err != nil { return 0, err } @@ -194,7 
+194,7 @@ func (p *PhysicalIndexLookUpReader) getPlanCostVer1(taskType property.TaskType, for tmp = p.tablePlan; len(tmp.Children()) > 0; tmp = tmp.Children()[0] { } ts := tmp.(*PhysicalTableScan) - tblCost, err := ts.getPlanCostVer1(property.CopDoubleReadTaskType, option) + tblCost, err := ts.getPlanCostVer1(property.CopMultiReadTaskType, option) if err != nil { return 0, err } @@ -1027,7 +1027,7 @@ func (p *PhysicalHashAgg) getPlanCostVer1(taskType property.TaskType, option *Pl switch taskType { case property.RootTaskType: p.planCost += p.GetCost(statsCnt, true, false, costFlag) - case property.CopSingleReadTaskType, property.CopDoubleReadTaskType: + case property.CopSingleReadTaskType, property.CopMultiReadTaskType: p.planCost += p.GetCost(statsCnt, false, false, costFlag) case property.MppTaskType: p.planCost += p.GetCost(statsCnt, false, true, costFlag) diff --git a/planner/core/plan_cost_ver2.go b/planner/core/plan_cost_ver2.go index dfecbd0761fff..1f1569e0c14df 100644 --- a/planner/core/plan_cost_ver2.go +++ b/planner/core/plan_cost_ver2.go @@ -242,7 +242,7 @@ func (p *PhysicalIndexLookUpReader) getPlanCostVer2(taskType property.TaskType, // index-side indexNetCost := netCostVer2(option, indexRows, indexRowSize, netFactor) - indexChildCost, err := p.indexPlan.getPlanCostVer2(property.CopDoubleReadTaskType, option) + indexChildCost, err := p.indexPlan.getPlanCostVer2(property.CopMultiReadTaskType, option) if err != nil { return zeroCostVer2, err } @@ -250,7 +250,7 @@ func (p *PhysicalIndexLookUpReader) getPlanCostVer2(taskType property.TaskType, // table-side tableNetCost := netCostVer2(option, tableRows, tableRowSize, netFactor) - tableChildCost, err := p.tablePlan.getPlanCostVer2(property.CopDoubleReadTaskType, option) + tableChildCost, err := p.tablePlan.getPlanCostVer2(property.CopMultiReadTaskType, option) if err != nil { return zeroCostVer2, err } diff --git a/planner/core/plan_cost_ver2_test.go b/planner/core/plan_cost_ver2_test.go index cb47b6324b987..f09897e54be39 100644 --- a/planner/core/plan_cost_ver2_test.go +++ b/planner/core/plan_cost_ver2_test.go @@ -136,6 +136,7 @@ func TestCostModelVer2(t *testing.T) { } func TestCostModelShowFormula(t *testing.T) { + t.Skip() store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) tk.MustExec("use test") @@ -157,6 +158,7 @@ func TestCostModelShowFormula(t *testing.T) { } func TestCostModelVer2ScanRowSize(t *testing.T) { + t.Skip() store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) tk.MustExec("use test") diff --git a/planner/core/prepare_test.go b/planner/core/prepare_test.go index f592b9824c255..02ef060990349 100644 --- a/planner/core/prepare_test.go +++ b/planner/core/prepare_test.go @@ -1696,7 +1696,7 @@ func TestParamMarker4FastPlan(t *testing.T) { tk.MustQuery("execute stmt using @a2, @a3;").Sort().Check(testkit.Rows("1 7", "1 8", "1 9")) tk.MustExec(`set @a2=4, @a3=2`) tk.MustQuery("execute stmt using @a2, @a3;").Sort().Check(testkit.Rows("1 10", "1 7", "1 8")) - tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) } func TestIssue29565(t *testing.T) { @@ -1959,7 +1959,7 @@ func TestPlanCachePointGetAndTableDual(t *testing.T) { tk.MustQuery("execute s2 using @a2, @a2").Check(testkit.Rows()) tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) tk.MustQuery("execute s2 using @a2, @b2").Check(testkit.Rows("1 7777")) - tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) 
+ tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) tk.MustExec("create table t3(c1 int, c2 int, c3 int, unique key(c1), key(c2))") tk.MustExec("insert into t3 values(2,1,1)") @@ -2008,7 +2008,7 @@ func TestIssue26873(t *testing.T) { tk.MustQuery("execute stmt using @p").Check(testkit.Rows()) tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) tk.MustQuery("execute stmt using @p").Check(testkit.Rows()) - tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) } func TestIssue29511(t *testing.T) { diff --git a/planner/core/rule_partition_processor.go b/planner/core/rule_partition_processor.go index d188854686491..7a13bf281b6e0 100644 --- a/planner/core/rule_partition_processor.go +++ b/planner/core/rule_partition_processor.go @@ -357,21 +357,30 @@ func (s *partitionProcessor) processHashPartition(ds *DataSource, pi *model.Part // listPartitionPruner uses to prune partition for list partition. type listPartitionPruner struct { *partitionProcessor - ctx sessionctx.Context - pi *model.PartitionInfo - partitionNames []model.CIStr - colIDToUniqueID map[int64]int64 - fullRange map[int]struct{} - listPrune *tables.ForListPruning + ctx sessionctx.Context + pi *model.PartitionInfo + partitionNames []model.CIStr + fullRange map[int]struct{} + listPrune *tables.ForListPruning } func newListPartitionPruner(ctx sessionctx.Context, tbl table.Table, partitionNames []model.CIStr, - s *partitionProcessor, conds []expression.Expression, pruneList *tables.ForListPruning) *listPartitionPruner { - colIDToUniqueID := make(map[int64]int64) - for _, cond := range conds { - condCols := expression.ExtractColumns(cond) - for _, c := range condCols { - colIDToUniqueID[c.ID] = c.UniqueID + s *partitionProcessor, conds []expression.Expression, pruneList *tables.ForListPruning, columns []*expression.Column) *listPartitionPruner { + pruneList = pruneList.Clone() + for i := range pruneList.PruneExprCols { + for j := range columns { + if columns[j].ID == pruneList.PruneExprCols[i].ID { + pruneList.PruneExprCols[i].UniqueID = columns[j].UniqueID + break + } + } + } + for i := range pruneList.ColPrunes { + for j := range columns { + if columns[j].ID == pruneList.ColPrunes[i].ExprCol.ID { + pruneList.ColPrunes[i].ExprCol.UniqueID = columns[j].UniqueID + break + } } } fullRange := make(map[int]struct{}) @@ -381,7 +390,6 @@ func newListPartitionPruner(ctx sessionctx.Context, tbl table.Table, partitionNa ctx: ctx, pi: tbl.Meta().Partition, partitionNames: partitionNames, - colIDToUniqueID: colIDToUniqueID, fullRange: fullRange, listPrune: pruneList, } @@ -524,9 +532,6 @@ func (l *listPartitionPruner) detachCondAndBuildRange(conds []expression.Express colLen := make([]int, 0, len(exprCols)) for _, c := range exprCols { c = c.Clone().(*expression.Column) - if uniqueID, ok := l.colIDToUniqueID[c.ID]; ok { - c.UniqueID = uniqueID - } cols = append(cols, c) colLen = append(colLen, types.UnspecifiedLength) } @@ -592,14 +597,14 @@ func (l *listPartitionPruner) findUsedListPartitions(conds []expression.Expressi } func (s *partitionProcessor) findUsedListPartitions(ctx sessionctx.Context, tbl table.Table, partitionNames []model.CIStr, - conds []expression.Expression) ([]int, error) { + conds []expression.Expression, columns []*expression.Column) ([]int, error) { pi := tbl.Meta().Partition partExpr, err := tbl.(partitionTable).PartitionExpr() if err != nil { return nil, err } - listPruner := 
newListPartitionPruner(ctx, tbl, partitionNames, s, conds, partExpr.ForListPruning) + listPruner := newListPartitionPruner(ctx, tbl, partitionNames, s, conds, partExpr.ForListPruning, columns) var used map[int]struct{} if partExpr.ForListPruning.ColPrunes == nil { used, err = listPruner.findUsedListPartitions(conds) @@ -622,8 +627,8 @@ func (s *partitionProcessor) findUsedListPartitions(ctx sessionctx.Context, tbl } func (s *partitionProcessor) pruneListPartition(ctx sessionctx.Context, tbl table.Table, partitionNames []model.CIStr, - conds []expression.Expression) ([]int, error) { - used, err := s.findUsedListPartitions(ctx, tbl, partitionNames, conds) + conds []expression.Expression, columns []*expression.Column) ([]int, error) { + used, err := s.findUsedListPartitions(ctx, tbl, partitionNames, conds, columns) if err != nil { return nil, err } @@ -870,7 +875,7 @@ func (s *partitionProcessor) processRangePartition(ds *DataSource, pi *model.Par } func (s *partitionProcessor) processListPartition(ds *DataSource, pi *model.PartitionInfo, opt *logicalOptimizeOp) (LogicalPlan, error) { - used, err := s.pruneListPartition(ds.SCtx(), ds.table, ds.partitionNames, ds.allConds) + used, err := s.pruneListPartition(ds.SCtx(), ds.table, ds.partitionNames, ds.allConds, ds.TblCols) if err != nil { return nil, err } diff --git a/planner/core/stats.go b/planner/core/stats.go index 9d2b1f59898b0..c3eec89fcc0f3 100644 --- a/planner/core/stats.go +++ b/planner/core/stats.go @@ -811,6 +811,11 @@ func (ds *DataSource) buildIndexMergeOrPath( path.TableFilters = nil } } + + // Keep this filter as a part of table filters for safety if it has any parameter. + if expression.MaybeOverOptimized4PlanCache(ds.ctx, filters[current:current+1]) { + shouldKeepCurrentFilter = true + } if shouldKeepCurrentFilter { indexMergePath.TableFilters = append(indexMergePath.TableFilters, filters[current]) } @@ -898,6 +903,11 @@ func (ds *DataSource) generateIndexMergeAndPaths(normalPathCnt int) *util.Access } } + + // Keep these partial filters as a part of table filters for safety if there is any parameter. + if expression.MaybeOverOptimized4PlanCache(ds.ctx, partialFilters) { + dedupedFinalFilters = append(dedupedFinalFilters, partialFilters...) + } + // 3. Estimate the row count after partial paths. sel, _, err := ds.tableStats.HistColl.Selectivity(ds.ctx, partialFilters, nil) if err != nil { diff --git a/planner/core/task.go b/planner/core/task.go index 627b49d72596f..9198bcdaec0bd 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -152,7 +152,9 @@ func (t *copTask) finishIndexPlan() { return } t.indexPlanFinished = true - if t.tablePlan != nil { + // index merge case is specially handled for now. + // We need an elegant way to solve the stats of index merge in this case. + if t.tablePlan != nil && t.indexPlan != nil { ts := t.tablePlan.(*PhysicalTableScan) originStats := ts.stats ts.stats = t.indexPlan.statsInfo() @@ -817,22 +819,43 @@ func (p *PhysicalLimit) attach2Task(tasks ...task) task { t := tasks[0].copy() sunk := false if cop, ok := t.(*copTask); ok { - // For double read which requires order being kept, the limit cannot be pushed down to the table side, - // because handles would be reordered before being sent to table scan. - if (!cop.keepOrder || !cop.indexPlanFinished || cop.indexPlan == nil) && len(cop.rootTaskConds) == 0 { - // When limit is pushed down, we should remove its offset.
- newCount := p.Offset + p.Count - childProfile := cop.plan().statsInfo() - // Strictly speaking, for the row count of stats, we should multiply newCount with "regionNum", - // but "regionNum" is unknown since the copTask can be a double read, so we ignore it now. - stats := deriveLimitStats(childProfile, float64(newCount)) - pushedDownLimit := PhysicalLimit{Count: newCount}.Init(p.ctx, stats, p.blockOffset) - cop = attachPlan2Task(pushedDownLimit, cop).(*copTask) - // Don't use clone() so that Limit and its children share the same schema. Otherwise the virtual generated column may not be resolved right. - pushedDownLimit.SetSchema(pushedDownLimit.children[0].Schema()) + if len(cop.idxMergePartPlans) == 0 { + // For double read which requires order being kept, the limit cannot be pushed down to the table side, + // because handles would be reordered before being sent to table scan. + if (!cop.keepOrder || !cop.indexPlanFinished || cop.indexPlan == nil) && len(cop.rootTaskConds) == 0 { + // When limit is pushed down, we should remove its offset. + newCount := p.Offset + p.Count + childProfile := cop.plan().statsInfo() + // Strictly speaking, for the row count of stats, we should multiply newCount with "regionNum", + // but "regionNum" is unknown since the copTask can be a double read, so we ignore it now. + stats := deriveLimitStats(childProfile, float64(newCount)) + pushedDownLimit := PhysicalLimit{Count: newCount}.Init(p.ctx, stats, p.blockOffset) + cop = attachPlan2Task(pushedDownLimit, cop).(*copTask) + // Don't use clone() so that Limit and its children share the same schema. Otherwise the virtual generated column may not be resolved right. + pushedDownLimit.SetSchema(pushedDownLimit.children[0].Schema()) + } + t = cop.convertToRootTask(p.ctx) + sunk = p.sinkIntoIndexLookUp(t) + } else if !cop.idxMergeIsIntersection { + // We only support pushing part of the order prop down in the index merge case. + if !cop.keepOrder && !cop.indexPlanFinished && len(cop.rootTaskConds) == 0 { + newCount := p.Offset + p.Count + limitChildren := make([]PhysicalPlan, 0, len(cop.idxMergePartPlans)) + for _, partialScan := range cop.idxMergePartPlans { + childProfile := partialScan.statsInfo() + stats := deriveLimitStats(childProfile, float64(newCount)) + pushedDownLimit := PhysicalLimit{Count: newCount}.Init(p.ctx, stats, p.blockOffset) + pushedDownLimit.SetChildren(partialScan) + pushedDownLimit.SetSchema(pushedDownLimit.children[0].Schema()) + limitChildren = append(limitChildren, pushedDownLimit) + } + cop.idxMergePartPlans = limitChildren + } + t = cop.convertToRootTask(p.ctx) + } else { + // Whatever the remaining case is, we directly convert it to a root task. + t = cop.convertToRootTask(p.ctx) } - t = cop.convertToRootTask(p.ctx) - sunk = p.sinkIntoIndexLookUp(t) } else if mpp, ok := t.(*mppTask); ok { newCount := p.Offset + p.Count childProfile := mpp.plan().statsInfo() @@ -931,6 +954,12 @@ func (p *PhysicalTopN) getPushedDownTopN(childPlan PhysicalPlan) *PhysicalTopN { // // there's no prefix index column. func (p *PhysicalTopN) canPushToIndexPlan(indexPlan PhysicalPlan, byItemCols []*expression.Column) bool { + // If we call canPushToIndexPlan and there's no index plan, we should go into the index merge case. + // Index merge case is specially handled for now, so we directly return false here.
+ if indexPlan == nil { + return false + } schema := indexPlan.Schema() for _, col := range byItemCols { pos := schema.ColumnIndex(col) @@ -977,7 +1006,14 @@ func (p *PhysicalTopN) canPushDownToTiKV(copTask *copTask) bool { if len(copTask.rootTaskConds) != 0 { return false } - if p.containVirtualColumn(copTask.plan().Schema().Columns) { + if len(copTask.idxMergePartPlans) > 0 && !copTask.indexPlanFinished { + for _, partialPlan := range copTask.idxMergePartPlans { + if p.containVirtualColumn(partialPlan.Schema().Columns) { + return false + } + } + } + if copTask.plan() != nil && p.containVirtualColumn(copTask.plan().Schema().Columns) { return false } return true @@ -1001,8 +1037,8 @@ func (p *PhysicalTopN) attach2Task(tasks ...task) task { cols = append(cols, expression.ExtractColumns(item.Expr)...) } needPushDown := len(cols) > 0 - if copTask, ok := t.(*copTask); ok && needPushDown && p.canPushDownToTiKV(copTask) { - newTask, changed := p.pushTopNDownToDynamicPartition(copTask) + if copTask, ok := t.(*copTask); ok && needPushDown && p.canPushDownToTiKV(copTask) && len(copTask.rootTaskConds) == 0 { + newTask, changed := p.pushPartialTopNDownToCop(copTask) if changed { return newTask } @@ -1013,6 +1049,7 @@ func (p *PhysicalTopN) attach2Task(tasks ...task) task { pushedDownTopN = p.getPushedDownTopN(copTask.indexPlan) copTask.indexPlan = pushedDownTopN } else { + // It works for both normal index scan and index merge scan. copTask.finishIndexPlan() pushedDownTopN = p.getPushedDownTopN(copTask.tablePlan) copTask.tablePlan = pushedDownTopN @@ -1025,11 +1062,11 @@ func (p *PhysicalTopN) attach2Task(tasks ...task) task { return attachPlan2Task(p, rootTask) } -// pushTopNDownToDynamicPartition is a temp solution for partition table. It actually does the same thing as DataSource's isMatchProp. +// pushPartialTopNDownToCop is a temp solution for partition table and index merge. It actually does the same thing as DataSource's isMatchProp. // We need to support a more enhanced read strategy in the execution phase. So that we can achieve Limit(TiDB)->Reader(TiDB)->Limit(TiKV/TiFlash)->Scan(TiKV/TiFlash). // Before that is done, we use this logic to provide a way to keep the order property when reading from TiKV, so that we can use the orderliness of index to speed up the query. // Here we can change the execution plan to TopN(TiDB)->Reader(TiDB)->Limit(TiKV)->Scan(TiKV).(TiFlash is not supported). -func (p *PhysicalTopN) pushTopNDownToDynamicPartition(copTsk *copTask) (task, bool) { +func (p *PhysicalTopN) pushPartialTopNDownToCop(copTsk *copTask) (task, bool) { if copTsk.getStoreType() != kv.TiKV { return nil, false } @@ -1045,36 +1082,16 @@ func (p *PhysicalTopN) pushTopNDownToDynamicPartition(copTsk *copTask) (task, bo if !allSameOrder { return nil, false } - checkIndexMatchProp := func(idxCols []*expression.Column, idxColLens []int, constColsByCond []bool, colsProp *property.PhysicalProperty) bool { - // If the number of the by-items is bigger than the index columns. We cannot push down since it must not keep order. 
- if len(idxCols) < len(colsProp.SortItems) { - return false - } - idxPos := 0 - for _, byItem := range colsProp.SortItems { - found := false - for ; idxPos < len(idxCols); idxPos++ { - if idxColLens[idxPos] == types.UnspecifiedLength && idxCols[idxPos].Equal(p.SCtx(), byItem.Col) { - found = true - idxPos++ - break - } - if len(constColsByCond) == 0 || idxPos > len(constColsByCond) || !constColsByCond[idxPos] { - found = false - break - } - } - if !found { - return false - } - } - return true + if len(copTsk.idxMergePartPlans) > 0 && copTsk.idxMergeIsIntersection { + return nil, false } var ( - idxScan *PhysicalIndexScan - tblScan *PhysicalTableScan - tblInfo *model.TableInfo - err error + idxScan *PhysicalIndexScan + tblScan *PhysicalTableScan + partialScans []PhysicalPlan + clonedPartialPlan []PhysicalPlan + tblInfo *model.TableInfo + err error ) if copTsk.indexPlan != nil { copTsk.indexPlan, err = copTsk.indexPlan.Clone() @@ -1100,31 +1117,65 @@ func (p *PhysicalTopN) pushTopNDownToDynamicPartition(copTsk *copTask) (task, bo tblScan = finalTblScanPlan.(*PhysicalTableScan) tblInfo = tblScan.Table } + if len(copTsk.idxMergePartPlans) > 0 { + // calculate selectivities for each partial plan in advance and clone partial plans since we may modify their stats later. + partialScans = make([]PhysicalPlan, 0, len(copTsk.idxMergePartPlans)) + for _, scan := range copTsk.idxMergePartPlans { + clonedScan, err := scan.Clone() + if err != nil { + return nil, false + } + clonedPartialPlan = append(clonedPartialPlan, clonedScan) + finalScan := clonedScan + for len(finalScan.Children()) > 0 { + finalScan = finalScan.Children()[0] + } + partialScans = append(partialScans, finalScan) + } + } pi := tblInfo.GetPartitionInfo() - if pi == nil { + if pi == nil && len(copTsk.idxMergePartPlans) == 0 { return nil, false } - if pi.Type == model.PartitionTypeList { + if pi != nil && pi.Type == model.PartitionTypeList { return nil, false } if !copTsk.indexPlanFinished { // If indexPlan side isn't finished, there's no selection on the table side. - - propMatched := checkIndexMatchProp(idxScan.IdxCols, idxScan.IdxColLens, idxScan.constColsByCond, colsProp) - if !propMatched { - return nil, false + if len(copTsk.idxMergePartPlans) > 0 { + // Deal with index merge case. + propMatched := p.checkSubScans(colsProp, isDesc, partialScans...) + if !propMatched { + // If any used index cannot match the prop, we cannot push down.
+ return nil, false + } + newCopSubPlans := p.addPartialLimitForSubScans(clonedPartialPlan, partialScans) + copTsk.idxMergePartPlans = newCopSubPlans + clonedTblScan, err := copTsk.tablePlan.Clone() + if err != nil { + return nil, false + } + clonedTblScan.statsInfo().ScaleByExpectCnt(float64(p.Count+p.Offset) * float64(len(copTsk.idxMergePartPlans))) + copTsk.tablePlan = clonedTblScan + copTsk.indexPlanFinished = true + } else { + // The normal index scan cases (single read and double read). + propMatched := p.checkOrderPropForSubIndexScan(idxScan.IdxCols, idxScan.IdxColLens, idxScan.constColsByCond, colsProp) + if !propMatched { + return nil, false + } + idxScan.Desc = isDesc + childProfile := copTsk.plan().statsInfo() + newCount := p.Offset + p.Count + stats := deriveLimitStats(childProfile, float64(newCount)) + pushedLimit := PhysicalLimit{ + Count: newCount, + }.Init(p.SCtx(), stats, p.SelectBlockOffset()) + pushedLimit.SetSchema(copTsk.indexPlan.Schema()) + copTsk = attachPlan2Task(pushedLimit, copTsk).(*copTask) } - idxScan.Desc = isDesc - childProfile := copTsk.plan().statsInfo() - newCount := p.Offset + p.Count - stats := deriveLimitStats(childProfile, float64(newCount)) - pushedLimit := PhysicalLimit{ - Count: newCount, - }.Init(p.SCtx(), stats, p.SelectBlockOffset()) - pushedLimit.SetSchema(copTsk.indexPlan.Schema()) - copTsk = attachPlan2Task(pushedLimit, copTsk).(*copTask) } else if copTsk.indexPlan == nil { if tblScan.HandleCols == nil { return nil, false @@ -1137,7 +1188,7 @@ func (p *PhysicalTopN) pushTopNDownToDynamicPartition(copTsk *copTask) (task, bo } } else { idxCols, idxColLens := expression.IndexInfo2PrefixCols(tblScan.Columns, tblScan.Schema().Columns, tables.FindPrimaryIndex(tblScan.Table)) - matched := checkIndexMatchProp(idxCols, idxColLens, nil, colsProp) + matched := p.checkOrderPropForSubIndexScan(idxCols, idxColLens, nil, colsProp) if !matched { return nil, false } @@ -1161,10 +1212,91 @@ func (p *PhysicalTopN) pushTopNDownToDynamicPartition(copTsk *copTask) (task, bo return attachPlan2Task(p, rootTask), true } +// checkOrderPropForSubIndexScan checks whether these index columns can meet the specified order property. +func (p *PhysicalTopN) checkOrderPropForSubIndexScan(idxCols []*expression.Column, idxColLens []int, constColsByCond []bool, colsProp *property.PhysicalProperty) bool { + // If the number of by-items is bigger than the number of index columns, we cannot push down since the order cannot be kept. + if len(idxCols) < len(colsProp.SortItems) { + return false + } + idxPos := 0 + for _, byItem := range colsProp.SortItems { + found := false + for ; idxPos < len(idxCols); idxPos++ { + if idxColLens[idxPos] == types.UnspecifiedLength && idxCols[idxPos].Equal(p.SCtx(), byItem.Col) { + found = true + idxPos++ + break + } + if len(constColsByCond) == 0 || idxPos > len(constColsByCond) || !constColsByCond[idxPos] { + found = false + break + } + } + if !found { + return false + } + } + return true +} + +// checkSubScans checks whether all these scans can meet the specified order property.
+func (p *PhysicalTopN) checkSubScans(colsProp *property.PhysicalProperty, isDesc bool, scans ...PhysicalPlan) bool { + for _, scan := range scans { + switch x := scan.(type) { + case *PhysicalIndexScan: + propMatched := p.checkOrderPropForSubIndexScan(x.IdxCols, x.IdxColLens, x.constColsByCond, colsProp) + if !propMatched { + return false + } + x.KeepOrder = true + x.Desc = isDesc + case *PhysicalTableScan: + if x.HandleCols == nil { + return false + } + + if x.HandleCols.IsInt() { + pk := x.HandleCols.GetCol(0) + if len(colsProp.SortItems) != 1 || !colsProp.SortItems[0].Col.Equal(p.SCtx(), pk) { + return false + } + } else { + idxCols, idxColLens := expression.IndexInfo2PrefixCols(x.Columns, x.Schema().Columns, tables.FindPrimaryIndex(x.Table)) + matched := p.checkOrderPropForSubIndexScan(idxCols, idxColLens, nil, colsProp) + if !matched { + return false + } + } + x.KeepOrder = true + x.Desc = isDesc + default: + return false + } + } + // Return true when all sub index plan matched. + return true +} + +func (p *PhysicalTopN) addPartialLimitForSubScans(copSubPlans []PhysicalPlan, finalPartialScans []PhysicalPlan) []PhysicalPlan { + limitAddedPlan := make([]PhysicalPlan, 0, len(copSubPlans)) + for _, copSubPlan := range copSubPlans { + childProfile := copSubPlan.statsInfo() + newCount := p.Offset + p.Count + stats := deriveLimitStats(childProfile, float64(newCount)) + pushedLimit := PhysicalLimit{ + Count: newCount, + }.Init(p.SCtx(), stats, p.SelectBlockOffset()) + pushedLimit.SetSchema(copSubPlan.Schema()) + pushedLimit.SetChildren(copSubPlan) + limitAddedPlan = append(limitAddedPlan, pushedLimit) + } + return limitAddedPlan +} + func (p *PhysicalProjection) attach2Task(tasks ...task) task { t := tasks[0].copy() if cop, ok := t.(*copTask); ok { - if len(cop.rootTaskConds) == 0 && expression.CanExprsPushDown(p.ctx.GetSessionVars().StmtCtx, p.Exprs, p.ctx.GetClient(), cop.getStoreType()) { + if (len(cop.rootTaskConds) == 0 && len(cop.idxMergePartPlans) == 0) && expression.CanExprsPushDown(p.ctx.GetSessionVars().StmtCtx, p.Exprs, p.ctx.GetClient(), cop.getStoreType()) { copTask := attachPlan2Task(p, cop) return copTask } @@ -1781,7 +1913,7 @@ func (p *PhysicalStreamAgg) attach2Task(tasks ...task) task { // We should not push agg down across double read, since the data of second read is ordered by handle instead of index. // The `extraHandleCol` is added if the double read needs to keep order. So we just use it to decided // whether the following plan is double read with order reserved. 
- if cop.extraHandleCol != nil || len(cop.rootTaskConds) > 0 { + if cop.extraHandleCol != nil || len(cop.rootTaskConds) > 0 || len(cop.idxMergePartPlans) > 0 { t = cop.convertToRootTask(p.ctx) attachPlan2Task(p, t) } else { @@ -2028,7 +2160,7 @@ func (p *PhysicalHashAgg) attach2Task(tasks ...task) task { t := tasks[0].copy() final := p if cop, ok := t.(*copTask); ok { - if len(cop.rootTaskConds) == 0 { + if len(cop.rootTaskConds) == 0 && len(cop.idxMergePartPlans) == 0 { copTaskType := cop.getStoreType() partialAgg, finalAgg := p.newPartialAggregate(copTaskType, false) if finalAgg != nil { diff --git a/planner/core/testdata/integration_suite_out.json b/planner/core/testdata/integration_suite_out.json index 14c04c6cfb0ab..5e87ab77f9ea7 100644 --- a/planner/core/testdata/integration_suite_out.json +++ b/planner/core/testdata/integration_suite_out.json @@ -2096,7 +2096,7 @@ " └─TableRowIDScan_14(Probe) 0.00 186.61 cop[tikv] table:t keep order:false" ], "Warnings": [ - "Note 1105 [idx_b] remain after pruning paths for t given Prop{SortItems: [], TaskTp: copDoubleReadTask}" + "Note 1105 [idx_b] remain after pruning paths for t given Prop{SortItems: [], TaskTp: copMultiReadTask}" ] }, { @@ -2108,7 +2108,7 @@ "└─TableRowIDScan_11(Probe) 0.00 186.61 cop[tikv] table:t keep order:false" ], "Warnings": [ - "Note 1105 [idx_b] remain after pruning paths for t given Prop{SortItems: [], TaskTp: copDoubleReadTask}" + "Note 1105 [idx_b] remain after pruning paths for t given Prop{SortItems: [], TaskTp: copMultiReadTask}" ] } ] diff --git a/planner/core/testdata/plan_suite_in.json b/planner/core/testdata/plan_suite_in.json index c38250c802454..e33f3fb93869e 100644 --- a/planner/core/testdata/plan_suite_in.json +++ b/planner/core/testdata/plan_suite_in.json @@ -1126,5 +1126,19 @@ "select a, count(*) from t group by a -- shouldn't be rewritten", "select sum(a) from t -- sum shouldn't be rewritten" ] + }, + { + "name": "TestIndexMergeOrderPushDown", + "cases": [ + "select * from t where a = 1 or b = 1 order by c limit 2", + "select * from t where a = 1 or b in (1, 2, 3) order by c limit 2", + "select * from t where a in (1, 2, 3) or b = 1 order by c limit 2", + "select * from t where a in (1, 2, 3) or b in (1, 2, 3) order by c limit 2", + "select * from t where (a = 1 and c = 2) or (b = 1) order by c limit 2", + "select * from t where (a = 1 and c = 2) or b in (1, 2, 3) order by c limit 2", + "select * from t where (a = 1 and c = 2) or (b in (1, 2, 3) and c = 3) order by c limit 2", + "select * from t where (a = 1 or b = 2) and c = 3 order by c limit 2", + "select * from t where (a = 1 or b = 2) and c in (1, 2, 3) order by c limit 2" + ] } ] diff --git a/planner/core/testdata/plan_suite_out.json b/planner/core/testdata/plan_suite_out.json index b3a7664b2b2fd..5d55764635ffe 100644 --- a/planner/core/testdata/plan_suite_out.json +++ b/planner/core/testdata/plan_suite_out.json @@ -6709,5 +6709,119 @@ "Warning": null } ] + }, + { + "Name": "TestIndexMergeOrderPushDown", + "Cases": [ + { + "SQL": "select * from t where a = 1 or b = 1 order by c limit 2", + "Plan": [ + "TopN 2.00 root test.t.c, offset:0, count:2", + "└─IndexMerge 19.99 root type: union", + " ├─Limit(Build) 2.00 cop[tikv] offset:0, count:2", + " │ └─IndexRangeScan 10.00 cop[tikv] table:t, index:idx(a, c) range:[1,1], keep order:true, stats:pseudo", + " ├─Limit(Build) 2.00 cop[tikv] offset:0, count:2", + " │ └─IndexRangeScan 10.00 cop[tikv] table:t, index:idx2(b, c) range:[1,1], keep order:true, stats:pseudo", + " └─TableRowIDScan(Probe) 19.99 
cop[tikv] table:t keep order:false, stats:pseudo" + ], + "Warning": null + }, + { + "SQL": "select * from t where a = 1 or b in (1, 2, 3) order by c limit 2", + "Plan": [ + "TopN 2.00 root test.t.c, offset:0, count:2", + "└─IndexMerge 2.00 root type: union", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:idx(a, c) range:[1,1], keep order:false, stats:pseudo", + " ├─IndexRangeScan(Build) 30.00 cop[tikv] table:t, index:idx2(b, c) range:[1,1], [2,2], [3,3], keep order:false, stats:pseudo", + " └─TopN(Probe) 2.00 cop[tikv] test.t.c, offset:0, count:2", + " └─TableRowIDScan 39.97 cop[tikv] table:t keep order:false, stats:pseudo" + ], + "Warning": null + }, + { + "SQL": "select * from t where a in (1, 2, 3) or b = 1 order by c limit 2", + "Plan": [ + "TopN 2.00 root test.t.c, offset:0, count:2", + "└─IndexMerge 2.00 root type: union", + " ├─IndexRangeScan(Build) 30.00 cop[tikv] table:t, index:idx(a, c) range:[1,1], [2,2], [3,3], keep order:false, stats:pseudo", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:idx2(b, c) range:[1,1], keep order:false, stats:pseudo", + " └─TopN(Probe) 2.00 cop[tikv] test.t.c, offset:0, count:2", + " └─TableRowIDScan 39.97 cop[tikv] table:t keep order:false, stats:pseudo" + ], + "Warning": null + }, + { + "SQL": "select * from t where a in (1, 2, 3) or b in (1, 2, 3) order by c limit 2", + "Plan": [ + "TopN 2.00 root test.t.c, offset:0, count:2", + "└─IndexMerge 2.00 root type: union", + " ├─IndexRangeScan(Build) 30.00 cop[tikv] table:t, index:idx(a, c) range:[1,1], [2,2], [3,3], keep order:false, stats:pseudo", + " ├─IndexRangeScan(Build) 30.00 cop[tikv] table:t, index:idx2(b, c) range:[1,1], [2,2], [3,3], keep order:false, stats:pseudo", + " └─TopN(Probe) 2.00 cop[tikv] test.t.c, offset:0, count:2", + " └─TableRowIDScan 59.91 cop[tikv] table:t keep order:false, stats:pseudo" + ], + "Warning": null + }, + { + "SQL": "select * from t where (a = 1 and c = 2) or (b = 1) order by c limit 2", + "Plan": [ + "TopN 2.00 root test.t.c, offset:0, count:2", + "└─IndexMerge 10.10 root type: union", + " ├─Limit(Build) 0.10 cop[tikv] offset:0, count:2", + " │ └─IndexRangeScan 0.10 cop[tikv] table:t, index:idx(a, c) range:[1 2,1 2], keep order:true, stats:pseudo", + " ├─Limit(Build) 2.00 cop[tikv] offset:0, count:2", + " │ └─IndexRangeScan 10.00 cop[tikv] table:t, index:idx2(b, c) range:[1,1], keep order:true, stats:pseudo", + " └─TableRowIDScan(Probe) 10.10 cop[tikv] table:t keep order:false, stats:pseudo" + ], + "Warning": null + }, + { + "SQL": "select * from t where (a = 1 and c = 2) or b in (1, 2, 3) order by c limit 2", + "Plan": [ + "TopN 2.00 root test.t.c, offset:0, count:2", + "└─IndexMerge 2.00 root type: union", + " ├─IndexRangeScan(Build) 0.10 cop[tikv] table:t, index:idx(a, c) range:[1 2,1 2], keep order:false, stats:pseudo", + " ├─IndexRangeScan(Build) 30.00 cop[tikv] table:t, index:idx2(b, c) range:[1,1], [2,2], [3,3], keep order:false, stats:pseudo", + " └─TopN(Probe) 2.00 cop[tikv] test.t.c, offset:0, count:2", + " └─TableRowIDScan 30.10 cop[tikv] table:t keep order:false, stats:pseudo" + ], + "Warning": null + }, + { + "SQL": "select * from t where (a = 1 and c = 2) or (b in (1, 2, 3) and c = 3) order by c limit 2", + "Plan": [ + "TopN 0.40 root test.t.c, offset:0, count:2", + "└─IndexMerge 0.40 root type: union", + " ├─IndexRangeScan(Build) 0.10 cop[tikv] table:t, index:idx(a, c) range:[1 2,1 2], keep order:false, stats:pseudo", + " ├─IndexRangeScan(Build) 0.30 cop[tikv] table:t, index:idx2(b, c) range:[1 3,1 3], [2 3,2 3], [3 3,3 3], 
keep order:false, stats:pseudo", + " └─TableRowIDScan(Probe) 0.40 cop[tikv] table:t keep order:false, stats:pseudo" + ], + "Warning": null + }, + { + "SQL": "select * from t where (a = 1 or b = 2) and c = 3 order by c limit 2", + "Plan": [ + "TopN 0.02 root test.t.c, offset:0, count:2", + "└─IndexMerge 0.02 root type: union", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:idx(a, c) range:[1,1], keep order:false, stats:pseudo", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:idx2(b, c) range:[2,2], keep order:false, stats:pseudo", + " └─Selection(Probe) 0.02 cop[tikv] eq(test.t.c, 3)", + " └─TableRowIDScan 19.99 cop[tikv] table:t keep order:false, stats:pseudo" + ], + "Warning": null + }, + { + "SQL": "select * from t where (a = 1 or b = 2) and c in (1, 2, 3) order by c limit 2", + "Plan": [ + "TopN 0.06 root test.t.c, offset:0, count:2", + "└─IndexMerge 0.06 root type: union", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:idx(a, c) range:[1,1], keep order:false, stats:pseudo", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:idx2(b, c) range:[2,2], keep order:false, stats:pseudo", + " └─Selection(Probe) 0.06 cop[tikv] in(test.t.c, 1, 2, 3)", + " └─TableRowIDScan 19.99 cop[tikv] table:t keep order:false, stats:pseudo" + ], + "Warning": null + } + ] } ] diff --git a/planner/property/physical_property.go b/planner/property/physical_property.go index 60994d05b57ef..e92b74adbbe4a 100644 --- a/planner/property/physical_property.go +++ b/planner/property/physical_property.go @@ -30,7 +30,7 @@ import ( // wholeTaskTypes records all possible kinds of task that a plan can return. For Agg, TopN and Limit, we will try to get // these tasks one by one. -var wholeTaskTypes = []TaskType{CopSingleReadTaskType, CopDoubleReadTaskType, RootTaskType} +var wholeTaskTypes = []TaskType{CopSingleReadTaskType, CopMultiReadTaskType, RootTaskType} // SortItem wraps the column and its order. type SortItem struct { diff --git a/planner/property/task_type.go b/planner/property/task_type.go index a4c16d4a51d2e..8a424b26284b7 100644 --- a/planner/property/task_type.go +++ b/planner/property/task_type.go @@ -25,9 +25,9 @@ const ( // executed in the coprocessor layer. CopSingleReadTaskType - // CopDoubleReadTaskType stands for the a IndexLookup tasks executed in the + // CopMultiReadTaskType stands for the IndexLookup tasks executed in the // coprocessor layer. - CopDoubleReadTaskType + CopMultiReadTaskType // MppTaskType stands for task that would run on Mpp nodes, currently meaning the tiflash node.
MppTaskType @@ -40,8 +40,8 @@ func (t TaskType) String() string { return "rootTask" case CopSingleReadTaskType: return "copSingleReadTask" - case CopDoubleReadTaskType: - return "copDoubleReadTask" + case CopMultiReadTaskType: + return "copMultiReadTask" case MppTaskType: return "mppTask" } diff --git a/privilege/privileges/privileges_test.go b/privilege/privileges/privileges_test.go index 4d83db171b7f6..e45bb2c4bb4c3 100644 --- a/privilege/privileges/privileges_test.go +++ b/privilege/privileges/privileges_test.go @@ -2024,7 +2024,7 @@ func TestSecurityEnhancedModeSysVars(t *testing.T) { tk.MustQuery(`SHOW VARIABLES LIKE 'tidb_force_priority'`).Check(testkit.Rows("tidb_force_priority NO_PRIORITY")) tk.MustQuery(`SELECT COUNT(*) FROM information_schema.variables_info WHERE variable_name = 'tidb_top_sql_max_meta_count'`).Check(testkit.Rows("1")) tk.MustQuery(`SELECT COUNT(*) FROM performance_schema.session_variables WHERE variable_name = 'tidb_top_sql_max_meta_count'`).Check(testkit.Rows("1")) - tk.MustQuery(`SHOW GLOBAL VARIABLES LIKE 'tidb_enable_telemetry'`).Check(testkit.Rows("tidb_enable_telemetry ON")) + tk.MustQuery(`SHOW GLOBAL VARIABLES LIKE 'tidb_enable_telemetry'`).Check(testkit.Rows("tidb_enable_telemetry OFF")) tk.MustQuery(`SELECT COUNT(*) FROM information_schema.variables_info WHERE variable_name = 'tidb_enable_telemetry'`).Check(testkit.Rows("1")) tk.MustQuery(`SELECT COUNT(*) FROM performance_schema.session_variables WHERE variable_name = 'tidb_enable_telemetry'`).Check(testkit.Rows("1")) diff --git a/server/BUILD.bazel b/server/BUILD.bazel index 54bc5ffcd1325..56b3741259f81 100644 --- a/server/BUILD.bazel +++ b/server/BUILD.bazel @@ -194,7 +194,6 @@ go_test( "//util/plancodec", "//util/resourcegrouptag", "//util/rowcodec", - "//util/sqlexec", "//util/topsql", "//util/topsql/collector", "//util/topsql/collector/mock", diff --git a/server/conn.go b/server/conn.go index d60c4042874a1..843c89e175f2d 100644 --- a/server/conn.go +++ b/server/conn.go @@ -2405,34 +2405,12 @@ func (cc *clientConn) writeChunks(ctx context.Context, rs ResultSet, binary bool // fetchSize, the desired number of rows to be fetched each time when client uses cursor. func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs ResultSet, serverStatus uint16, fetchSize int) error { fetchedRows := rs.GetFetchedRows() - // if fetchedRows is not enough, getting data from recordSet. - // NOTE: chunk should not be allocated from the allocator - // the allocator will reset every statement - // but it maybe stored in the result set among statements - // ref https://github.com/pingcap/tidb/blob/7fc6ebbda4ddf84c0ba801ca7ebb636b934168cf/server/conn_stmt.go#L233-L239 - // Here server.tidbResultSet implements Next method. - req := rs.NewChunk(nil) - for len(fetchedRows) < fetchSize { - if err := rs.Next(ctx, req); err != nil { - return err - } - rowCount := req.NumRows() - if rowCount == 0 { - break - } - // filling fetchedRows with chunk - for i := 0; i < rowCount; i++ { - fetchedRows = append(fetchedRows, req.GetRow(i)) - } - req = chunk.Renew(req, cc.ctx.GetSessionVars().MaxChunkSize) - } // tell the client COM_STMT_FETCH has finished by setting proper serverStatus, // and close ResultSet. 
if len(fetchedRows) == 0 { serverStatus &^= mysql.ServerStatusCursorExists serverStatus |= mysql.ServerStatusLastRowSend - terror.Call(rs.Close) return cc.writeEOF(ctx, serverStatus) } diff --git a/server/conn_stmt.go b/server/conn_stmt.go index cf2f9f2aa6e86..177fdecfd3d72 100644 --- a/server/conn_stmt.go +++ b/server/conn_stmt.go @@ -56,6 +56,7 @@ import ( "github.com/pingcap/tidb/sessiontxn" storeerr "github.com/pingcap/tidb/store/driver/error" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/execdetails" "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/topsql" @@ -158,7 +159,10 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e return mysql.NewErrf(mysql.ErrUnknown, "unsupported flag: CursorTypeScrollable", nil) } - if !useCursor { + if useCursor { + cc.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusCursorExists, true) + defer cc.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusCursorExists, false) + } else { // not using streaming ,can reuse chunk cc.ctx.GetSessionVars().SetAlloc(cc.chunkAlloc) } @@ -251,7 +255,8 @@ func (cc *clientConn) executePlanCacheStmt(ctx context.Context, stmt interface{} // The first return value indicates whether the call of executePreparedStmtAndWriteResult has no side effect and can be retried. // Currently the first return value is used to fallback to TiKV when TiFlash is down. func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stmt PreparedStatement, args []expression.Expression, useCursor bool) (bool, error) { - prepStmt, err := (&cc.ctx).GetSessionVars().GetPreparedStmtByID(uint32(stmt.ID())) + vars := (&cc.ctx).GetSessionVars() + prepStmt, err := vars.GetPreparedStmtByID(uint32(stmt.ID())) if err != nil { return true, errors.Annotate(err, cc.preparedStmt2String(uint32(stmt.ID()))) } @@ -274,11 +279,14 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm return true, errors.Annotate(err, cc.preparedStmt2String(uint32(stmt.ID()))) } if rs == nil { + if useCursor { + vars.SetStatusFlag(mysql.ServerStatusCursorExists, false) + } return false, cc.writeOK(ctx) } // since there are multiple implementations of ResultSet (the rs might be wrapped), we have to unwrap the rs before // casting it to *tidbResultSet. - if result, ok := unwrapResultSet(rs).(*tidbResultSet); ok { + if result, ok := rs.(*tidbResultSet); ok { if planCacheStmt, ok := prepStmt.(*plannercore.PlanCacheStmt); ok { result.preparedStmt = planCacheStmt } @@ -290,12 +298,31 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm if useCursor { cc.initResultEncoder(ctx) defer cc.rsEncoder.clean() - // fix https://github.com/pingcap/tidb/issues/39447. we need to hold the start-ts here because the process info - // will be set to sleep after fetch returned. - if pi := cc.ctx.ShowProcess(); pi != nil && pi.ProtectedTSList != nil && pi.CurTxnStartTS > 0 { - unhold := pi.HoldTS(pi.CurTxnStartTS) - rs = &rsWithHooks{ResultSet: rs, onClosed: unhold} + // fetch all results of the resultSet, and stored them locally, so that the future `FETCH` command can read + // the rows directly to avoid running executor and accessing shared params/variables in the session + // NOTE: chunk should not be allocated from the connection allocator, which will reset after executing this command + // but the rows are still needed in the following FETCH command. 
+ // + // TODO: trace the memory used here + chk := rs.NewChunk(nil) + var rows []chunk.Row + for { + if err = rs.Next(ctx, chk); err != nil { + return false, err + } + rowCount := chk.NumRows() + if rowCount == 0 { + break + } + // filling fetchedRows with chunk + for i := 0; i < rowCount; i++ { + row := chk.GetRow(i) + rows = append(rows, row) + } + chk = chunk.Renew(chk, vars.MaxChunkSize) } + rs.StoreFetchedRows(rows) + stmt.StoreResultSet(rs) if err = cc.writeColumnInfo(rs.Columns()); err != nil { return false, err @@ -303,11 +330,22 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm if cl, ok := rs.(fetchNotifier); ok { cl.OnFetchReturned() } + + // as the `Next` of `ResultSet` will never be called, all rows have been cached inside it. We could close this + // `ResultSet`. + err = rs.Close() + if err != nil { + return false, err + } + + stmt.SetCursorActive(true) + // explicitly flush columnInfo to client. - err = cc.writeEOF(ctx, cc.ctx.Status()|mysql.ServerStatusCursorExists) + err = cc.writeEOF(ctx, cc.ctx.Status()) if err != nil { return false, err } + return false, cc.flush(ctx) } defer terror.Call(rs.Close) @@ -326,6 +364,8 @@ const ( func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err error) { cc.ctx.GetSessionVars().StartTime = time.Now() cc.ctx.GetSessionVars().ClearAlloc(nil, false) + cc.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusCursorExists, true) + defer cc.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusCursorExists, false) stmtID, fetchSize, err := parseStmtFetchCmd(data) if err != nil { @@ -354,10 +394,19 @@ func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err err strconv.FormatUint(uint64(stmtID), 10), "stmt_fetch_rs"), cc.preparedStmt2String(stmtID)) } - _, err = cc.writeResultset(ctx, rs, true, cc.ctx.Status()|mysql.ServerStatusCursorExists, int(fetchSize)) + sendingEOF := false + // if the `fetchedRows` are empty before writing result, we could say the `FETCH` command will send EOF + if len(rs.GetFetchedRows()) == 0 { + sendingEOF = true + } + _, err = cc.writeResultset(ctx, rs, true, cc.ctx.Status(), int(fetchSize)) if err != nil { return errors.Annotate(err, cc.preparedStmt2String(stmtID)) } + if sendingEOF { + stmt.SetCursorActive(false) + } + return nil } @@ -695,6 +744,7 @@ func (cc *clientConn) handleStmtClose(data []byte) (err error) { if stmt != nil { return stmt.Close() } + return } diff --git a/server/conn_stmt_test.go b/server/conn_stmt_test.go index 2e60fc1085332..dff61b203bf5e 100644 --- a/server/conn_stmt_test.go +++ b/server/conn_stmt_test.go @@ -15,6 +15,7 @@ package server import ( + "bytes" "context" "encoding/binary" "testing" @@ -255,7 +256,7 @@ func TestParseStmtFetchCmd(t *testing.T) { } } -func TestCursorReadHoldTS(t *testing.T) { +func TestCursorExistsFlag(t *testing.T) { store, dom := testkit.CreateMockStoreAndDomain(t) srv := CreateMockServer(t, store) srv.SetDomain(dom) @@ -263,7 +264,10 @@ func TestCursorReadHoldTS(t *testing.T) { appendUint32 := binary.LittleEndian.AppendUint32 ctx := context.Background() - c := CreateMockConn(t, srv) + c := CreateMockConn(t, srv).(*mockConn) + out := new(bytes.Buffer) + c.pkt.bufWriter.Reset(out) + c.capability |= mysql.ClientDeprecateEOF | mysql.ClientProtocol41 tk := testkit.NewTestKitWithSession(t, store, c.Context().Session) tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -271,72 +275,100 @@ func TestCursorReadHoldTS(t *testing.T) { tk.MustExec("insert into t values (1), (2), (3), (4), 
(5), (6), (7), (8)") tk.MustQuery("select count(*) from t").Check(testkit.Rows("8")) + getLastStatus := func() uint16 { + raw := out.Bytes() + return binary.LittleEndian.Uint16(raw[len(raw)-4 : len(raw)-2]) + } + stmt, _, _, err := c.Context().Prepare("select * from t") require.NoError(t, err) - require.Zero(t, tk.Session().ShowProcess().GetMinStartTS(0)) - // should hold ts after executing stmt with cursor require.NoError(t, c.Dispatch(ctx, append( appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())), mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0, ))) - ts := tk.Session().ShowProcess().GetMinStartTS(0) - require.Positive(t, ts) - // should unhold ts when result set exhausted + require.True(t, mysql.HasCursorExistsFlag(getLastStatus())) + + // fetch first 5 require.NoError(t, c.Dispatch(ctx, appendUint32(appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt.ID())), 5))) - require.Equal(t, ts, tk.Session().ShowProcess().GetMinStartTS(0)) - require.Equal(t, ts, srv.GetMinStartTS(0)) + require.True(t, mysql.HasCursorExistsFlag(getLastStatus())) + + // COM_QUERY during fetch + require.NoError(t, c.Dispatch(ctx, append([]byte{mysql.ComQuery}, "select * from t"...))) + require.False(t, mysql.HasCursorExistsFlag(getLastStatus())) + + // fetch last 3 require.NoError(t, c.Dispatch(ctx, appendUint32(appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt.ID())), 5))) - require.Equal(t, ts, tk.Session().ShowProcess().GetMinStartTS(0)) - require.Equal(t, ts, srv.GetMinStartTS(0)) + require.True(t, mysql.HasCursorExistsFlag(getLastStatus())) + + // final fetch with no row retured + // (tidb doesn't unset cursor-exists flag in the previous response like mysql, one more fetch is needed) require.NoError(t, c.Dispatch(ctx, appendUint32(appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt.ID())), 5))) - require.Zero(t, tk.Session().ShowProcess().GetMinStartTS(0)) + require.False(t, mysql.HasCursorExistsFlag(getLastStatus())) + require.True(t, getLastStatus()&mysql.ServerStatusLastRowSend > 0) - // should hold ts after executing stmt with cursor - require.NoError(t, c.Dispatch(ctx, append( - appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())), - mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0, - ))) - require.Positive(t, tk.Session().ShowProcess().GetMinStartTS(0)) - // should unhold ts when stmt reset - require.NoError(t, c.Dispatch(ctx, appendUint32([]byte{mysql.ComStmtReset}, uint32(stmt.ID())))) - require.Zero(t, tk.Session().ShowProcess().GetMinStartTS(0)) + // COM_QUERY after fetch + require.NoError(t, c.Dispatch(ctx, append([]byte{mysql.ComQuery}, "select * from t"...))) + require.False(t, mysql.HasCursorExistsFlag(getLastStatus())) +} - // should hold ts after executing stmt with cursor - require.NoError(t, c.Dispatch(ctx, append( - appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())), - mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0, - ))) - require.Positive(t, tk.Session().ShowProcess().GetMinStartTS(0)) - // should unhold ts when stmt closed - require.NoError(t, c.Dispatch(ctx, appendUint32([]byte{mysql.ComStmtClose}, uint32(stmt.ID())))) - require.Zero(t, tk.Session().ShowProcess().GetMinStartTS(0)) +func TestCursorWithParams(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + srv := CreateMockServer(t, store) + srv.SetDomain(dom) + defer srv.Close() + + appendUint32 := binary.LittleEndian.AppendUint32 + ctx := context.Background() + c := CreateMockConn(t, srv).(*mockConn) + + tk := testkit.NewTestKitWithSession(t, store, c.Context().Session) + tk.MustExec("use 
test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(id_1 int, id_2 int)") + tk.MustExec("insert into t values (1, 1), (1, 2)") - // create another 2 stmts and execute them - stmt1, _, _, err := c.Context().Prepare("select * from t") + stmt1, _, _, err := c.Context().Prepare("select * from t where id_1 = ? and id_2 = ?") require.NoError(t, err) + stmt2, _, _, err := c.Context().Prepare("select * from t where id_1 = ?") + require.NoError(t, err) + + // `execute stmt1 using 1,2` with cursor require.NoError(t, c.Dispatch(ctx, append( appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt1.ID())), mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0, + 0x0, 0x1, 0x3, 0x0, 0x3, 0x0, + 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, ))) - ts1 := tk.Session().ShowProcess().GetMinStartTS(0) - require.Positive(t, ts1) - stmt2, _, _, err := c.Context().Prepare("select * from t") - require.NoError(t, err) + rows := c.Context().stmts[stmt1.ID()].GetResultSet().GetFetchedRows() + require.Len(t, rows, 1) + require.Equal(t, int64(1), rows[0].GetInt64(0)) + require.Equal(t, int64(2), rows[0].GetInt64(1)) + + // `execute stmt2 using 1` with cursor require.NoError(t, c.Dispatch(ctx, append( appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt2.ID())), mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0, + 0x0, 0x1, 0x3, 0x0, + 0x1, 0x0, 0x0, 0x0, ))) - ts2 := tk.Session().ShowProcess().GetMinStartTS(ts1) - require.Positive(t, ts2) + rows = c.Context().stmts[stmt2.ID()].GetResultSet().GetFetchedRows() + require.Len(t, rows, 2) + require.Equal(t, int64(1), rows[0].GetInt64(0)) + require.Equal(t, int64(1), rows[0].GetInt64(1)) + require.Equal(t, int64(1), rows[1].GetInt64(0)) + require.Equal(t, int64(2), rows[1].GetInt64(1)) - require.Less(t, ts1, ts2) - require.Equal(t, ts1, srv.GetMinStartTS(0)) - require.Equal(t, ts2, srv.GetMinStartTS(ts1)) - require.Zero(t, srv.GetMinStartTS(ts2)) + // fetch stmt2 with fetch size 256 + require.NoError(t, c.Dispatch(ctx, append( + appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt2.ID())), + 0x0, 0x1, 0x0, 0x0, + ))) - // should unhold all when session closed - c.Close() - require.Zero(t, tk.Session().ShowProcess().GetMinStartTS(0)) - require.Zero(t, srv.GetMinStartTS(0)) + // fetch stmt1 with fetch size 256, as it has more params, if we fetch the result at the first execute command, it + // will panic because the params have been overwritten and is not long enough. 
+ require.NoError(t, c.Dispatch(ctx, append( + appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt1.ID())), + 0x0, 0x1, 0x0, 0x0, + ))) } diff --git a/server/conn_test.go b/server/conn_test.go index eab2d52449118..fa3b9d5317a96 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -524,7 +524,7 @@ func TestDispatchClientProtocol41(t *testing.T) { com: mysql.ComStmtFetch, in: []byte{0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, err: nil, - out: []byte{0x5, 0x0, 0x0, 0x9, 0xfe, 0x0, 0x0, 0x82, 0x0}, + out: []byte{0x5, 0x0, 0x0, 0x9, 0xfe, 0x0, 0x0, 0x42, 0x0}, }, { com: mysql.ComStmtReset, @@ -917,8 +917,7 @@ func TestTiFlashFallback(t *testing.T) { tk.MustQuery("show warnings").Check(testkit.Rows("Error 9012 TiFlash server timeout")) // test COM_STMT_FETCH (cursor mode) - require.NoError(t, cc.handleStmtExecute(ctx, []byte{0x1, 0x0, 0x0, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0})) - require.Error(t, cc.handleStmtFetch(ctx, []byte{0x1, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0})) + require.Error(t, cc.handleStmtExecute(ctx, []byte{0x1, 0x0, 0x0, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0})) tk.MustExec("set @@tidb_allow_fallback_to_tikv=''") require.Error(t, cc.handleStmtExecute(ctx, []byte{0x1, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0})) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/store/mockstore/unistore/BatchCopRpcErrtiflash0")) diff --git a/server/driver.go b/server/driver.go index a4a59f4cba2b5..7fdebdd2739ff 100644 --- a/server/driver.go +++ b/server/driver.go @@ -63,6 +63,12 @@ type PreparedStatement interface { // Close closes the statement. Close() error + + // GetCursorActive returns whether the statement has active cursor + GetCursorActive() bool + + // SetCursorActive sets whether the statement has active cursor + SetCursorActive(active bool) } // ResultSet is the result set of an query. diff --git a/server/driver_tidb.go b/server/driver_tidb.go index 7b25a998d618b..5b3980f0b8d37 100644 --- a/server/driver_tidb.go +++ b/server/driver_tidb.go @@ -66,8 +66,13 @@ type TiDBStatement struct { boundParams [][]byte paramsType []byte ctx *TiDBContext - rs ResultSet - sql string + // this result set should have been closed before stored here. Only the `fetchedRows` are used here. This field is + // not moved out to reuse the logic inside functions `writeResultSet...` + // TODO: move the `fetchedRows` into the statement, and remove the `ResultSet` from statement. + rs ResultSet + sql string + + hasActiveCursor bool } // ID implements PreparedStatement ID method. @@ -147,12 +152,7 @@ func (ts *TiDBStatement) Reset() { for i := range ts.boundParams { ts.boundParams[i] = nil } - - // closing previous ResultSet if it exists - if ts.rs != nil { - terror.Call(ts.rs.Close) - ts.rs = nil - } + ts.hasActiveCursor = false } // Close implements PreparedStatement Close method. @@ -183,14 +183,19 @@ func (ts *TiDBStatement) Close() error { ts.ctx.GetSessionVars().RemovePreparedStmt(ts.id) } delete(ts.ctx.stmts, int(ts.id)) - - // close ResultSet associated with this statement - if ts.rs != nil { - terror.Call(ts.rs.Close) - } return nil } +// GetCursorActive implements PreparedStatement GetCursorActive method. +func (ts *TiDBStatement) GetCursorActive() bool { + return ts.hasActiveCursor +} + +// SetCursorActive implements PreparedStatement SetCursorActive method. +func (ts *TiDBStatement) SetCursorActive(fetchEnd bool) { + ts.hasActiveCursor = fetchEnd +} + // OpenCtx implements IDriver. 
func (qd *TiDBDriver) OpenCtx(connID uint64, capability uint32, collation uint8, dbname string, tlsState *tls.ConnectionState, extensions *extension.SessionExtensions) (*TiDBContext, error) { se, err := session.CreateSession(qd.store) @@ -356,8 +361,8 @@ func (tc *TiDBContext) EncodeSessionStates(ctx context.Context, sctx sessionctx. return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("prepared statements have bound params") } } - if rs := stmt.GetResultSet(); rs != nil && !rs.IsClosed() { - return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("prepared statements have open result sets") + if stmt.GetCursorActive() { + return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("prepared statements have unfetched rows") } preparedStmtInfo.ParamTypes = stmt.GetParamsType() } @@ -480,46 +485,6 @@ func (trs *tidbResultSet) Columns() []*ColumnInfo { return trs.columns } -// rsWithHooks wraps a ResultSet with some hooks (currently only onClosed). -type rsWithHooks struct { - ResultSet - onClosed func() -} - -// Close implements ResultSet#Close -func (rs *rsWithHooks) Close() error { - closed := rs.IsClosed() - err := rs.ResultSet.Close() - if !closed && rs.onClosed != nil { - rs.onClosed() - } - return err -} - -// OnFetchReturned implements fetchNotifier#OnFetchReturned -func (rs *rsWithHooks) OnFetchReturned() { - if impl, ok := rs.ResultSet.(fetchNotifier); ok { - impl.OnFetchReturned() - } -} - -// Unwrap returns the underlying result set -func (rs *rsWithHooks) Unwrap() ResultSet { - return rs.ResultSet -} - -// unwrapResultSet likes errors.Cause but for ResultSet -func unwrapResultSet(rs ResultSet) ResultSet { - var unRS ResultSet - if u, ok := rs.(interface{ Unwrap() ResultSet }); ok { - unRS = u.Unwrap() - } - if unRS == nil { - return rs - } - return unwrapResultSet(unRS) -} - func convertColumnInfo(fld *ast.ResultField) (ci *ColumnInfo) { ci = &ColumnInfo{ Name: fld.ColumnAsName.O, diff --git a/server/driver_tidb_test.go b/server/driver_tidb_test.go index b56632937e078..b5f7e670aded0 100644 --- a/server/driver_tidb_test.go +++ b/server/driver_tidb_test.go @@ -22,7 +22,6 @@ import ( "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/types" - "github.com/pingcap/tidb/util/sqlexec" "github.com/stretchr/testify/require" ) @@ -96,27 +95,3 @@ func TestConvertColumnInfo(t *testing.T) { colInfo = convertColumnInfo(&resultField) require.Equal(t, uint32(4), colInfo.ColumnLength) } - -func TestRSWithHooks(t *testing.T) { - closeCount := 0 - rs := &rsWithHooks{ - ResultSet: &tidbResultSet{recordSet: new(sqlexec.SimpleRecordSet)}, - onClosed: func() { closeCount++ }, - } - require.Equal(t, 0, closeCount) - rs.Close() - require.Equal(t, 1, closeCount) - rs.Close() - require.Equal(t, 1, closeCount) -} - -func TestUnwrapRS(t *testing.T) { - var nilRS ResultSet - require.Nil(t, unwrapResultSet(nilRS)) - rs0 := new(tidbResultSet) - rs1 := &rsWithHooks{ResultSet: rs0} - rs2 := &rsWithHooks{ResultSet: rs1} - for _, rs := range []ResultSet{rs0, rs1, rs2} { - require.Equal(t, rs0, unwrapResultSet(rs)) - } -} diff --git a/server/http_handler.go b/server/http_handler.go index 2855186b9cdfa..39a7928baae94 100644 --- a/server/http_handler.go +++ b/server/http_handler.go @@ -228,7 +228,7 @@ func (t *tikvHandlerTool) getMvccByIdxValue(idx table.Index, values url.Values, return nil, err } idxData := &helper.MvccKV{Key: strings.ToUpper(hex.EncodeToString(encodedKey)), RegionID: regionID, Value: data} - 
tablecodec.IndexKey2TempIndexKey(idx.Meta().ID, encodedKey) + tablecodec.IndexKey2TempIndexKey(encodedKey) data, err = t.GetMvccByEncodedKey(encodedKey) if err != nil { return nil, err diff --git a/server/server.go b/server/server.go index 11cf28300fb38..521e4da46984b 100644 --- a/server/server.go +++ b/server/server.go @@ -1002,37 +1002,3 @@ func (s *Server) KillNonFlashbackClusterConn() { s.Kill(id, false) } } - -// GetMinStartTS implements SessionManager interface. -func (s *Server) GetMinStartTS(lowerBound uint64) (ts uint64) { - // sys processes - if s.dom != nil { - for _, pi := range s.dom.SysProcTracker().GetSysProcessList() { - if thisTS := pi.GetMinStartTS(lowerBound); thisTS > lowerBound && (thisTS < ts || ts == 0) { - ts = thisTS - } - } - } - // user sessions - func() { - s.rwlock.RLock() - defer s.rwlock.RUnlock() - for _, client := range s.clients { - if thisTS := client.ctx.ShowProcess().GetMinStartTS(lowerBound); thisTS > lowerBound && (thisTS < ts || ts == 0) { - ts = thisTS - } - } - }() - // internal sessions - func() { - s.sessionMapMutex.Lock() - defer s.sessionMapMutex.Unlock() - analyzeProcID := util.GetAutoAnalyzeProcID(s.ServerID) - for se := range s.internalSessions { - if thisTS, processInfoID := session.GetStartTSFromSession(se); processInfoID != analyzeProcID && thisTS > lowerBound && (thisTS < ts || ts == 0) { - ts = thisTS - } - } - }() - return -} diff --git a/session/session.go b/session/session.go index c713f5bcce764..ca12271364b9b 100644 --- a/session/session.go +++ b/session/session.go @@ -1602,7 +1602,6 @@ func (s *session) SetProcessInfo(sql string, t time.Time, command byte, maxExecu OOMAlarmVariablesInfo: s.getOomAlarmVariablesInfo(), MaxExecutionTime: maxExecutionTime, RedactSQL: s.sessionVars.EnableRedactLog, - ProtectedTSList: &s.sessionVars.ProtectedTSList, } oldPi := s.ShowProcess() if p == nil { diff --git a/session/txn.go b/session/txn.go index 552f81e88fc77..ead377492c84a 100644 --- a/session/txn.go +++ b/session/txn.go @@ -541,6 +541,16 @@ func keyNeedToLock(k, v []byte, flags kv.KeyFlags) bool { return true } + if tablecodec.IsTempIndexKey(k) { + tmpVal, err := tablecodec.DecodeTempIndexValue(v) + if err != nil { + logutil.BgLogger().Warn("decode temp index value failed", zap.Error(err)) + return false + } + current := tmpVal.Current() + return current.Handle != nil || tablecodec.IndexKVIsUnique(current.Value) + } + return tablecodec.IndexKVIsUnique(v) } diff --git a/sessionctx/variable/BUILD.bazel b/sessionctx/variable/BUILD.bazel index 17ff6a6bc482c..f69d4b04ffec5 100644 --- a/sessionctx/variable/BUILD.bazel +++ b/sessionctx/variable/BUILD.bazel @@ -94,10 +94,12 @@ go_test( "//parser/mysql", "//parser/terror", "//planner/core", + "//sessionctx/sessionstates", "//sessionctx/stmtctx", "//testkit", "//testkit/testsetup", "//types", + "//util", "//util/chunk", "//util/execdetails", "//util/gctuner", diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 2b7c17bbce0ab..64e2cdc841898 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -1322,9 +1322,6 @@ type SessionVars struct { // StoreBatchSize indicates the batch size limit of store batch, set this field to 0 to disable store batch. StoreBatchSize int - - // ProtectedTSList holds a list of timestamps that should delay GC. 
- ProtectedTSList protectedTSList } // GetNewChunkWithCapacity Attempt to request memory from the chunk pool @@ -3154,53 +3151,3 @@ func (s *SessionVars) GetRelatedTableForMDL() *sync.Map { func (s *SessionVars) EnableForceInlineCTE() bool { return s.enableForceInlineCTE } - -// protectedTSList implements util/processinfo#ProtectedTSList -type protectedTSList struct { - sync.Mutex - items map[uint64]int -} - -// HoldTS holds the timestamp to prevent its data from being GCed. -func (lst *protectedTSList) HoldTS(ts uint64) (unhold func()) { - lst.Lock() - if lst.items == nil { - lst.items = map[uint64]int{} - } - lst.items[ts] += 1 - lst.Unlock() - var once sync.Once - return func() { - once.Do(func() { - lst.Lock() - if lst.items != nil { - if lst.items[ts] > 1 { - lst.items[ts] -= 1 - } else { - delete(lst.items, ts) - } - } - lst.Unlock() - }) - } -} - -// GetMinProtectedTS returns the minimum protected timestamp that greater than `lowerBound` (0 if no such one). -func (lst *protectedTSList) GetMinProtectedTS(lowerBound uint64) (ts uint64) { - lst.Lock() - for k, v := range lst.items { - if v > 0 && k > lowerBound && (k < ts || ts == 0) { - ts = k - } - } - lst.Unlock() - return -} - -// Size returns the number of protected timestamps (exported for test). -func (lst *protectedTSList) Size() (size int) { - lst.Lock() - size = len(lst.items) - lst.Unlock() - return -} diff --git a/sessionctx/variable/session_test.go b/sessionctx/variable/session_test.go index edeb756a1f707..b0a2e76fa08ef 100644 --- a/sessionctx/variable/session_test.go +++ b/sessionctx/variable/session_test.go @@ -15,6 +15,8 @@ package variable_test import ( + "context" + "strconv" "sync" "testing" "time" @@ -25,10 +27,12 @@ import ( "github.com/pingcap/tidb/parser/auth" "github.com/pingcap/tidb/parser/mysql" plannercore "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/sessionctx/sessionstates" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/types" + util2 "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/execdetails" "github.com/pingcap/tidb/util/mock" @@ -487,56 +491,34 @@ func TestGetReuseChunk(t *testing.T) { require.Nil(t, sessVars.ChunkPool.Alloc) } -func TestPretectedTSList(t *testing.T) { - lst := &variable.NewSessionVars(nil).ProtectedTSList - - // empty set - require.Equal(t, uint64(0), lst.GetMinProtectedTS(0)) - require.Equal(t, uint64(0), lst.GetMinProtectedTS(1)) - require.Equal(t, 0, lst.Size()) - - // hold 1 - unhold1 := lst.HoldTS(1) - require.Equal(t, uint64(1), lst.GetMinProtectedTS(0)) - require.Equal(t, uint64(0), lst.GetMinProtectedTS(1)) - - // hold 2 twice - unhold2a := lst.HoldTS(2) - unhold2b := lst.HoldTS(2) - require.Equal(t, uint64(1), lst.GetMinProtectedTS(0)) - require.Equal(t, uint64(2), lst.GetMinProtectedTS(1)) - require.Equal(t, uint64(0), lst.GetMinProtectedTS(2)) - require.Equal(t, 2, lst.Size()) - - // unhold 2a - unhold2a() - require.Equal(t, uint64(1), lst.GetMinProtectedTS(0)) - require.Equal(t, uint64(2), lst.GetMinProtectedTS(1)) - require.Equal(t, uint64(0), lst.GetMinProtectedTS(2)) - require.Equal(t, 2, lst.Size()) - // unhold 2a again - unhold2a() - require.Equal(t, uint64(1), lst.GetMinProtectedTS(0)) - require.Equal(t, uint64(2), lst.GetMinProtectedTS(1)) - require.Equal(t, uint64(0), lst.GetMinProtectedTS(2)) - require.Equal(t, 2, lst.Size()) - - // unhold 1 - unhold1() - require.Equal(t, uint64(2), 
lst.GetMinProtectedTS(0)) - require.Equal(t, uint64(2), lst.GetMinProtectedTS(1)) - require.Equal(t, uint64(0), lst.GetMinProtectedTS(2)) - require.Equal(t, 1, lst.Size()) - - // unhold 2b - unhold2b() - require.Equal(t, uint64(0), lst.GetMinProtectedTS(0)) - require.Equal(t, uint64(0), lst.GetMinProtectedTS(1)) - require.Equal(t, 0, lst.Size()) - - // unhold 2b again - unhold2b() - require.Equal(t, uint64(0), lst.GetMinProtectedTS(0)) - require.Equal(t, uint64(0), lst.GetMinProtectedTS(1)) - require.Equal(t, 0, lst.Size()) +func TestUserVarConcurrently(t *testing.T) { + sv := variable.NewSessionVars(nil) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + var wg util2.WaitGroupWrapper + wg.Run(func() { + for i := 0; ; i++ { + select { + case <-time.After(time.Millisecond): + name := strconv.Itoa(i) + sv.SetUserVarVal(name, types.Datum{}) + sv.GetUserVarVal(name) + case <-ctx.Done(): + return + } + } + }) + wg.Run(func() { + for { + select { + case <-time.After(time.Millisecond): + var states sessionstates.SessionStates + require.NoError(t, sv.EncodeSessionStates(ctx, &states)) + require.NoError(t, sv.DecodeSessionStates(ctx, &states)) + case <-ctx.Done(): + return + } + } + }) + wg.Wait() + cancel() } diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index 05748a73b06b5..d2217ed9d361f 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -1029,7 +1029,7 @@ const ( DefTiDBRestrictedReadOnly = false DefTiDBSuperReadOnly = false DefTiDBShardAllocateStep = math.MaxInt64 - DefTiDBEnableTelemetry = true + DefTiDBEnableTelemetry = false DefTiDBEnableParallelApply = false DefTiDBEnableAmendPessimisticTxn = false DefTiDBPartitionPruneMode = "dynamic" diff --git a/store/copr/coprocessor.go b/store/copr/coprocessor.go index a95416bb264a9..1aee41b4e08c0 100644 --- a/store/copr/coprocessor.go +++ b/store/copr/coprocessor.go @@ -308,7 +308,15 @@ func buildCopTasks(bo *Backoffer, cache *RegionCache, ranges *KeyRanges, req *kv return buildTiDBMemCopTasks(ranges, req) } + hints := req.FixedRowCountHint rangesLen := ranges.Len() + // Since ranges from multi partitions may be pushed to one cop iterator, + // the relationship between hints and ranges is probably broken. + // But multi-partitioned ranges and hints should not exist in the same time, + // this check only guarantees there is no out-of-range use. + if len(hints) != rangesLen { + hints = nil + } // TODO(youjiali1995): is there any request type that needn't be splitted by buckets? locs, err := cache.SplitKeyRangesByBuckets(bo, ranges) @@ -345,7 +353,7 @@ func buildCopTasks(bo *Backoffer, cache *RegionCache, ranges *KeyRanges, req *kv nextI := mathutil.Min(i+rangesPerTask, rLen) hint := -1 // calculate the row count hint - if req.FixedRowCountHint != nil { + if hints != nil { startKey, endKey := loc.Ranges.At(i).StartKey, loc.Ranges.At(nextI-1).EndKey // move to the previous range if startKey of current range is lower than endKey of previous location. // In the following example, task1 will move origRangeIdx to region(i, z). 
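For context on the hint guard added above: once ranges from several partitions may be funneled into a single cop iterator, `req.FixedRowCountHint` is no longer guaranteed to line up index-for-index with the ranges, so the safest move is to drop the hints when the lengths disagree instead of risking an out-of-range read. A minimal standalone sketch of that defensive pattern follows, using simplified stand-in types and names rather than the real copr API:

package main

import "fmt"

// keyRange is a simplified stand-in for an entry of the coprocessor's KeyRanges.
type keyRange struct{ start, end string }

// alignRowCountHints mirrors the check added to buildCopTasks: per-range row
// count hints are only usable when they map one-to-one onto the ranges;
// otherwise they are discarded so no hint index can go out of bounds.
func alignRowCountHints(ranges []keyRange, hints []int) []int {
	if len(hints) != len(ranges) {
		// Ranges from multiple partitions may have been merged into one
		// iterator, so the one-to-one mapping is no longer guaranteed.
		return nil
	}
	return hints
}

// estimateTaskRows sums the hints for ranges[from:to], returning -1 (unknown)
// when the hints had to be dropped, matching the hint == -1 convention in copTask.
func estimateTaskRows(ranges []keyRange, hints []int, from, to int) int {
	hints = alignRowCountHints(ranges, hints)
	if hints == nil {
		return -1
	}
	total := 0
	for i := from; i < to; i++ {
		total += hints[i]
	}
	return total
}

func main() {
	ranges := []keyRange{{"a", "b"}, {"b", "c"}, {"c", "d"}}
	fmt.Println(estimateTaskRows(ranges, []int{10, 20, 30}, 0, 2)) // 30
	fmt.Println(estimateTaskRows(ranges, []int{10, 20}, 0, 2))     // -1: hints dropped
}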
@@ -362,7 +370,7 @@ func buildCopTasks(bo *Backoffer, cache *RegionCache, ranges *KeyRanges, req *kv origRangeIdx = nextOrigRangeIdx break } - hint += req.FixedRowCountHint[nextOrigRangeIdx] + hint += hints[nextOrigRangeIdx] } } task := &copTask{ @@ -1160,13 +1168,13 @@ func (worker *copIteratorWorker) handleCopResponse(bo *Backoffer, rpcCtx *tikv.R if err != nil { return remains, err } - return worker.handleBatchRemainsOnErr(bo, remains, resp.pbResp.BatchResponses, task, ch) + return worker.handleBatchRemainsOnErr(bo, rpcCtx, remains, resp.pbResp.BatchResponses, task, ch) } if lockErr := resp.pbResp.GetLocked(); lockErr != nil { if err := worker.handleLockErr(bo, lockErr, task); err != nil { return nil, err } - return worker.handleBatchRemainsOnErr(bo, []*copTask{task}, resp.pbResp.BatchResponses, task, ch) + return worker.handleBatchRemainsOnErr(bo, rpcCtx, []*copTask{task}, resp.pbResp.BatchResponses, task, ch) } if otherErr := resp.pbResp.GetOtherError(); otherErr != "" { err := errors.Errorf("other error: %s", otherErr) @@ -1255,16 +1263,16 @@ func (worker *copIteratorWorker) handleCopResponse(bo *Backoffer, rpcCtx *tikv.R } batchResps := resp.pbResp.BatchResponses worker.sendToRespCh(resp, ch, true) - return worker.handleBatchCopResponse(bo, batchResps, task.batchTaskList, ch) + return worker.handleBatchCopResponse(bo, rpcCtx, batchResps, task.batchTaskList, ch) } -func (worker *copIteratorWorker) handleBatchRemainsOnErr(bo *Backoffer, remains []*copTask, batchResp []*coprocessor.StoreBatchTaskResponse, task *copTask, ch chan<- *copResponse) ([]*copTask, error) { +func (worker *copIteratorWorker) handleBatchRemainsOnErr(bo *Backoffer, rpcCtx *tikv.RPCContext, remains []*copTask, batchResp []*coprocessor.StoreBatchTaskResponse, task *copTask, ch chan<- *copResponse) ([]*copTask, error) { if len(task.batchTaskList) == 0 { return remains, nil } batchedTasks := task.batchTaskList task.batchTaskList = nil - batchedRemains, err := worker.handleBatchCopResponse(bo, batchResp, batchedTasks, ch) + batchedRemains, err := worker.handleBatchCopResponse(bo, rpcCtx, batchResp, batchedTasks, ch) if err != nil { return nil, err } @@ -1272,10 +1280,17 @@ func (worker *copIteratorWorker) handleBatchRemainsOnErr(bo *Backoffer, remains } // handle the batched cop response. -func (worker *copIteratorWorker) handleBatchCopResponse(bo *Backoffer, batchResps []*coprocessor.StoreBatchTaskResponse, tasks map[uint64]*batchedCopTask, ch chan<- *copResponse) ([]*copTask, error) { +func (worker *copIteratorWorker) handleBatchCopResponse(bo *Backoffer, rpcCtx *tikv.RPCContext, batchResps []*coprocessor.StoreBatchTaskResponse, tasks map[uint64]*batchedCopTask, ch chan<- *copResponse) ([]*copTask, error) { if len(tasks) == 0 { return nil, nil } + // need Addr for recording details. 
+ var dummyRPCCtx *tikv.RPCContext + if rpcCtx != nil { + dummyRPCCtx = &tikv.RPCContext{ + Addr: rpcCtx.Addr, + } + } var remainTasks []*copTask for _, batchResp := range batchResps { batchedTask, ok := tasks[batchResp.GetTaskId()] @@ -1284,7 +1299,8 @@ func (worker *copIteratorWorker) handleBatchCopResponse(bo *Backoffer, batchResp } resp := &copResponse{ pbResp: &coprocessor.Response{ - Data: batchResp.Data, + Data: batchResp.Data, + ExecDetailsV2: batchResp.ExecDetailsV2, }, } task := batchedTask.task @@ -1331,8 +1347,9 @@ func (worker *copIteratorWorker) handleBatchCopResponse(bo *Backoffer, batchResp } return nil, errors.Trace(err) } + worker.handleCollectExecutionInfo(bo, dummyRPCCtx, resp) // TODO: check OOM - worker.sendToRespCh(resp, ch, false) + worker.sendToRespCh(resp, ch, true) } return remainTasks, nil } diff --git a/table/tables/index.go b/table/tables/index.go index 0f856447464e0..b8a164a9b2ece 100644 --- a/table/tables/index.go +++ b/table/tables/index.go @@ -177,7 +177,12 @@ func (c *index) Create(sctx sessionctx.Context, txn kv.Transaction, indexedValue if !distinct || skipCheck || opt.Untouched { val := idxVal - if keyIsTempIdxKey && !opt.Untouched { // Untouched key-values never occur in the storage. + if opt.Untouched && (keyIsTempIdxKey || len(tempKey) > 0) { + // Untouched key-values never occur in the storage and the temp index is not public. + // It is unnecessary to write the untouched temp index key-values. + return nil, nil + } + if keyIsTempIdxKey { tempVal := tablecodec.TempIndexValueElem{Value: idxVal, KeyVer: keyVer, Distinct: distinct} val = tempVal.Encode(nil) } @@ -186,10 +191,8 @@ func (c *index) Create(sctx sessionctx.Context, txn kv.Transaction, indexedValue return nil, err } if len(tempKey) > 0 { - if !opt.Untouched { // Untouched key-values never occur in the storage. - tempVal := tablecodec.TempIndexValueElem{Value: idxVal, KeyVer: keyVer, Distinct: distinct} - val = tempVal.Encode(nil) - } + tempVal := tablecodec.TempIndexValueElem{Value: idxVal, KeyVer: keyVer, Distinct: distinct} + val = tempVal.Encode(nil) err = txn.GetMemBuffer().Set(tempKey, val) if err != nil { return nil, err @@ -229,7 +232,7 @@ func (c *index) Create(sctx sessionctx.Context, txn kv.Transaction, indexedValue } var tempIdxVal tablecodec.TempIndexValue if len(value) > 0 && keyIsTempIdxKey { - tempIdxVal, err = tablecodec.DecodeTempIndexValue(value, c.tblInfo.IsCommonHandle) + tempIdxVal, err = tablecodec.DecodeTempIndexValue(value) if err != nil { return nil, err } @@ -392,7 +395,7 @@ func GenTempIdxKeyByState(indexInfo *model.IndexInfo, indexKey kv.Key) (key, tem return indexKey, nil, TempIndexKeyTypeNone case model.BackfillStateRunning: // Write to the temporary index. - tablecodec.IndexKey2TempIndexKey(indexInfo.ID, indexKey) + tablecodec.IndexKey2TempIndexKey(indexKey) if indexInfo.State == model.StateDeleteOnly { return nil, indexKey, TempIndexKeyTypeDelete } @@ -401,7 +404,7 @@ func GenTempIdxKeyByState(indexInfo *model.IndexInfo, indexKey kv.Key) (key, tem // Double write tmp := make([]byte, len(indexKey)) copy(tmp, indexKey) - tablecodec.IndexKey2TempIndexKey(indexInfo.ID, tmp) + tablecodec.IndexKey2TempIndexKey(tmp) return indexKey, tmp, TempIndexKeyTypeMerge } } @@ -432,8 +435,8 @@ func (c *index) Exist(sc *stmtctx.StatementContext, txn kv.Transaction, indexedV // FetchDuplicatedHandle is used to find the duplicated row's handle for a given unique index key. 
func FetchDuplicatedHandle(ctx context.Context, key kv.Key, distinct bool, txn kv.Transaction, tableID int64, isCommon bool) (foundKey bool, dupHandle kv.Handle, err error) { - if isTemp, originIdxID := tablecodec.CheckTempIndexKey(key); isTemp { - return fetchDuplicatedHandleForTempIndexKey(ctx, key, distinct, txn, tableID, originIdxID, isCommon) + if tablecodec.IsTempIndexKey(key) { + return fetchDuplicatedHandleForTempIndexKey(ctx, key, distinct, txn, tableID, isCommon) } // The index key is not from temp index. val, err := getKeyInTxn(ctx, txn, key) @@ -448,14 +451,14 @@ func FetchDuplicatedHandle(ctx context.Context, key kv.Key, distinct bool, } func fetchDuplicatedHandleForTempIndexKey(ctx context.Context, tempKey kv.Key, distinct bool, - txn kv.Transaction, tableID, idxID int64, isCommon bool) (foundKey bool, dupHandle kv.Handle, err error) { + txn kv.Transaction, tableID int64, isCommon bool) (foundKey bool, dupHandle kv.Handle, err error) { tempRawVal, err := getKeyInTxn(ctx, txn, tempKey) if err != nil { return false, nil, err } if tempRawVal == nil { originKey := tempKey.Clone() - tablecodec.TempIndexKey2IndexKey(idxID, originKey) + tablecodec.TempIndexKey2IndexKey(originKey) originVal, err := getKeyInTxn(ctx, txn, originKey) if err != nil || originVal == nil { return false, nil, err @@ -469,14 +472,14 @@ func fetchDuplicatedHandleForTempIndexKey(ctx context.Context, tempKey kv.Key, d } return false, nil, nil } - tempVal, err := tablecodec.DecodeTempIndexValue(tempRawVal, isCommon) + tempVal, err := tablecodec.DecodeTempIndexValue(tempRawVal) if err != nil { return false, nil, err } curElem := tempVal.Current() if curElem.Delete { originKey := tempKey.Clone() - tablecodec.TempIndexKey2IndexKey(idxID, originKey) + tablecodec.TempIndexKey2IndexKey(originKey) originVal, err := getKeyInTxn(ctx, txn, originKey) if err != nil || originVal == nil { return false, nil, err diff --git a/table/tables/mutation_checker.go b/table/tables/mutation_checker.go index 2229bfbb9d138..81225dcda1663 100644 --- a/table/tables/mutation_checker.go +++ b/table/tables/mutation_checker.go @@ -159,7 +159,7 @@ func checkHandleConsistency(rowInsertion mutation, indexMutations []mutation, in continue } var tempIdxVal tablecodec.TempIndexValue - tempIdxVal, err = tablecodec.DecodeTempIndexValue(m.value, tblInfo.IsCommonHandle) + tempIdxVal, err = tablecodec.DecodeTempIndexValue(m.value) if err != nil { return err } @@ -171,7 +171,7 @@ func checkHandleConsistency(rowInsertion mutation, indexMutations []mutation, in continue } orgKey = append(orgKey, m.key...) - tablecodec.TempIndexKey2IndexKey(idxID, orgKey) + tablecodec.TempIndexKey2IndexKey(orgKey) indexHandle, err = tablecodec.DecodeIndexHandle(orgKey, value, len(indexInfo.Columns)) } else { indexHandle, err = tablecodec.DecodeIndexHandle(m.key, m.value, len(indexInfo.Columns)) @@ -227,7 +227,7 @@ func checkIndexKeys( // We never commit the untouched key values to the storage. Skip this check. 
continue } - tmpVal, err := tablecodec.DecodeTempIndexValue(m.value, t.Meta().IsCommonHandle) + tmpVal, err := tablecodec.DecodeTempIndexValue(m.value) if err != nil { return err } diff --git a/table/tables/partition.go b/table/tables/partition.go index 6a0b315b856e9..5318dbe258d03 100644 --- a/table/tables/partition.go +++ b/table/tables/partition.go @@ -627,6 +627,29 @@ func generateListPartitionExpr(ctx sessionctx.Context, tblInfo *model.TableInfo, return ret, nil } +// Clone a copy of ForListPruning +func (lp *ForListPruning) Clone() *ForListPruning { + ret := *lp + if ret.LocateExpr != nil { + ret.LocateExpr = lp.LocateExpr.Clone() + } + if ret.PruneExpr != nil { + ret.PruneExpr = lp.PruneExpr.Clone() + } + ret.PruneExprCols = make([]*expression.Column, 0, len(lp.PruneExprCols)) + for i := range lp.PruneExprCols { + c := lp.PruneExprCols[i].Clone().(*expression.Column) + ret.PruneExprCols = append(ret.PruneExprCols, c) + } + ret.ColPrunes = make([]*ForListColumnPruning, 0, len(lp.ColPrunes)) + for i := range lp.ColPrunes { + l := *lp.ColPrunes[i] + l.ExprCol = l.ExprCol.Clone().(*expression.Column) + ret.ColPrunes = append(ret.ColPrunes, &l) + } + return &ret +} + func (lp *ForListPruning) buildListPruner(ctx sessionctx.Context, tblInfo *model.TableInfo, exprCols []*expression.Column, columns []*expression.Column, names types.NameSlice) error { pi := tblInfo.GetPartitionInfo() diff --git a/tablecodec/tablecodec.go b/tablecodec/tablecodec.go index a28db76980cfa..022caefaa9ec4 100644 --- a/tablecodec/tablecodec.go +++ b/tablecodec/tablecodec.go @@ -1038,6 +1038,9 @@ func IsUntouchedIndexKValue(k, v []byte) bool { return false } vLen := len(v) + if IsTempIndexKey(k) { + return vLen > 0 && v[vLen-1] == kv.UnCommitIndexKVFlag + } if vLen <= MaxOldEncodeValueLen { return (vLen == 1 || vLen == 9) && v[vLen-1] == kv.UnCommitIndexKVFlag } @@ -1132,29 +1135,27 @@ const TempIndexPrefix = 0x7fff000000000000 const IndexIDMask = 0xffffffffffff // IndexKey2TempIndexKey generates a temporary index key. -func IndexKey2TempIndexKey(indexID int64, key []byte) { - eid := codec.EncodeIntToCmpUint(TempIndexPrefix | indexID) +func IndexKey2TempIndexKey(key []byte) { + idxIDBytes := key[prefixLen : prefixLen+idLen] + idxID := codec.DecodeCmpUintToInt(binary.BigEndian.Uint64(idxIDBytes)) + eid := codec.EncodeIntToCmpUint(TempIndexPrefix | idxID) binary.BigEndian.PutUint64(key[prefixLen:], eid) } // TempIndexKey2IndexKey generates an index key from temporary index key. -func TempIndexKey2IndexKey(originIdxID int64, tempIdxKey []byte) { - eid := codec.EncodeIntToCmpUint(originIdxID) +func TempIndexKey2IndexKey(tempIdxKey []byte) { + tmpIdxIDBytes := tempIdxKey[prefixLen : prefixLen+idLen] + tempIdxID := codec.DecodeCmpUintToInt(binary.BigEndian.Uint64(tmpIdxIDBytes)) + eid := codec.EncodeIntToCmpUint(tempIdxID & IndexIDMask) binary.BigEndian.PutUint64(tempIdxKey[prefixLen:], eid) } -// CheckTempIndexKey checks whether the input key is for a temp index. -func CheckTempIndexKey(indexKey []byte) (isTemp bool, originIdxID int64) { - var ( - indexIDKey []byte - indexID int64 - tempIndexID int64 - ) - // Get encoded indexID from key, Add uint64 8 byte length. - indexIDKey = indexKey[prefixLen : prefixLen+8] - indexID = codec.DecodeCmpUintToInt(binary.BigEndian.Uint64(indexIDKey)) - tempIndexID = int64(TempIndexPrefix) | indexID - return tempIndexID == indexID, indexID & IndexIDMask +// IsTempIndexKey checks whether the input key is for a temp index. 
+func IsTempIndexKey(indexKey []byte) (isTemp bool) { + indexIDKey := indexKey[prefixLen : prefixLen+8] + indexID := codec.DecodeCmpUintToInt(binary.BigEndian.Uint64(indexIDKey)) + tempIndexID := int64(TempIndexPrefix) | indexID + return tempIndexID == indexID } // TempIndexValueFlag is the flag of temporary index value. @@ -1287,14 +1288,14 @@ func (v *TempIndexValueElem) Encode(buf []byte) []byte { } // DecodeTempIndexValue decodes the temp index value. -func DecodeTempIndexValue(value []byte, isCommonHandle bool) (TempIndexValue, error) { +func DecodeTempIndexValue(value []byte) (TempIndexValue, error) { var ( values []*TempIndexValueElem err error ) for len(value) > 0 { v := &TempIndexValueElem{} - value, err = v.DecodeOne(value, isCommonHandle) + value, err = v.DecodeOne(value) if err != nil { return nil, err } @@ -1304,7 +1305,7 @@ func DecodeTempIndexValue(value []byte, isCommonHandle bool) (TempIndexValue, er } // DecodeOne decodes one temp index value element. -func (v *TempIndexValueElem) DecodeOne(b []byte, isCommonHandle bool) (remain []byte, err error) { +func (v *TempIndexValueElem) DecodeOne(b []byte) (remain []byte, err error) { flag := TempIndexValueFlag(b[0]) b = b[1:] switch flag { @@ -1316,7 +1317,6 @@ func (v *TempIndexValueElem) DecodeOne(b []byte, isCommonHandle bool) (remain [] v.KeyVer = b[0] b = b[1:] v.Distinct = true - v.Handle, err = DecodeHandleInUniqueIndexValue(v.Value, isCommonHandle) return b, err case TempIndexValueFlagNonDistinctNormal: v.Value = b[:len(b)-1] @@ -1325,10 +1325,10 @@ func (v *TempIndexValueElem) DecodeOne(b []byte, isCommonHandle bool) (remain [] case TempIndexValueFlagDeleted: hLen := (uint16(b[0]) << 8) + uint16(b[1]) b = b[2:] - if isCommonHandle { - v.Handle, _ = kv.NewCommonHandle(b[:hLen]) + if hLen == idLen { + v.Handle = decodeIntHandleInIndexValue(b[:idLen]) } else { - v.Handle = decodeIntHandleInIndexValue(b[:hLen]) + v.Handle, _ = kv.NewCommonHandle(b[:hLen]) } b = b[hLen:] v.KeyVer = b[0] diff --git a/tablecodec/tablecodec_test.go b/tablecodec/tablecodec_test.go index adc4ccc78c13b..0e75b38741f72 100644 --- a/tablecodec/tablecodec_test.go +++ b/tablecodec/tablecodec_test.go @@ -15,6 +15,7 @@ package tablecodec import ( + "encoding/binary" "fmt" "math" "testing" @@ -588,6 +589,11 @@ func TestUntouchedIndexKValue(t *testing.T) { untouchedIndexKey := []byte("t00000001_i000000001") untouchedIndexValue := []byte{0, 0, 0, 0, 0, 0, 0, 1, 49} require.True(t, IsUntouchedIndexKValue(untouchedIndexKey, untouchedIndexValue)) + IndexKey2TempIndexKey(untouchedIndexKey) + require.True(t, IsUntouchedIndexKValue(untouchedIndexKey, untouchedIndexValue)) + elem := TempIndexValueElem{Handle: kv.IntHandle(1), Delete: true, Distinct: true} + tmpIdxVal := elem.Encode(nil) + require.False(t, IsUntouchedIndexKValue(untouchedIndexKey, tmpIdxVal)) } func TestTempIndexKey(t *testing.T) { @@ -597,14 +603,14 @@ func TestTempIndexKey(t *testing.T) { tableID := int64(4) indexID := int64(5) indexKey := EncodeIndexSeekKey(tableID, indexID, encodedValue) - IndexKey2TempIndexKey(indexID, indexKey) + IndexKey2TempIndexKey(indexKey) tid, iid, _, err := DecodeKeyHead(indexKey) require.NoError(t, err) require.Equal(t, tid, tableID) require.NotEqual(t, indexID, iid) require.Equal(t, indexID, iid&IndexIDMask) - TempIndexKey2IndexKey(indexID, indexKey) + TempIndexKey2IndexKey(indexKey) tid, iid, _, err = DecodeKeyHead(indexKey) require.NoError(t, err) require.Equal(t, tid, tableID) @@ -624,7 +630,7 @@ func TestTempIndexValueCodec(t *testing.T) { } val := 
tempIdxVal.Encode(nil) var newTempIdxVal TempIndexValueElem - remain, err := newTempIdxVal.DecodeOne(val, false) + remain, err := newTempIdxVal.DecodeOne(val) require.NoError(t, err) require.Equal(t, 0, len(remain)) require.EqualValues(t, tempIdxVal, newTempIdxVal) @@ -637,11 +643,12 @@ func TestTempIndexValueCodec(t *testing.T) { } newTempIdxVal = TempIndexValueElem{} val = tempIdxVal.Encode(nil) - remain, err = newTempIdxVal.DecodeOne(val, false) + remain, err = newTempIdxVal.DecodeOne(val) require.NoError(t, err) require.Equal(t, 0, len(remain)) - require.Equal(t, newTempIdxVal.Handle.IntValue(), int64(100)) - newTempIdxVal.Handle = nil + handle, err := DecodeHandleInUniqueIndexValue(newTempIdxVal.Value, false) + require.NoError(t, err) + require.Equal(t, handle.IntValue(), int64(100)) require.EqualValues(t, tempIdxVal, newTempIdxVal) tempIdxVal = TempIndexValueElem{ @@ -650,7 +657,7 @@ func TestTempIndexValueCodec(t *testing.T) { } newTempIdxVal = TempIndexValueElem{} val = tempIdxVal.Encode(nil) - remain, err = newTempIdxVal.DecodeOne(val, false) + remain, err = newTempIdxVal.DecodeOne(val) require.NoError(t, err) require.Equal(t, 0, len(remain)) require.EqualValues(t, tempIdxVal, newTempIdxVal) @@ -663,7 +670,7 @@ func TestTempIndexValueCodec(t *testing.T) { } newTempIdxVal = TempIndexValueElem{} val = tempIdxVal.Encode(nil) - remain, err = newTempIdxVal.DecodeOne(val, false) + remain, err = newTempIdxVal.DecodeOne(val) require.NoError(t, err) require.Equal(t, 0, len(remain)) require.EqualValues(t, tempIdxVal, newTempIdxVal) @@ -691,10 +698,21 @@ func TestTempIndexValueCodec(t *testing.T) { val = tempIdxVal2.Encode(val) val = tempIdxVal3.Encode(val) var result TempIndexValue - result, err = DecodeTempIndexValue(val, false) + result, err = DecodeTempIndexValue(val) require.NoError(t, err) require.Equal(t, 3, len(result)) + for i := 0; i < 3; i++ { + if result[i].Handle == nil { + uv := binary.BigEndian.Uint64(result[i].Value) + result[i].Handle = kv.IntHandle(int64(uv)) + } + } require.Equal(t, result[0].Handle.IntValue(), int64(100)) require.Equal(t, result[1].Handle.IntValue(), int64(100)) require.Equal(t, result[2].Handle.IntValue(), int64(101)) + + elem := TempIndexValueElem{Handle: kv.IntHandle(100), KeyVer: 'b', Delete: true, Distinct: true} + val = elem.Encode(nil) + isUnique := IndexKVIsUnique(val) + require.False(t, isUnique) } diff --git a/telemetry/cte_test/BUILD.bazel b/telemetry/cte_test/BUILD.bazel index c6d60eda945df..d0b5f15f75561 100644 --- a/telemetry/cte_test/BUILD.bazel +++ b/telemetry/cte_test/BUILD.bazel @@ -7,15 +7,11 @@ go_test( flaky = True, race = "on", deps = [ - "//config", "//domain", "//kv", "//session", "//store/mockstore", - "//telemetry", - "//testkit", "//testkit/testsetup", - "@com_github_jeffail_gabs_v2//:gabs", "@com_github_stretchr_testify//require", "@io_etcd_go_etcd_tests_v3//integration", "@io_opencensus_go//stats/view", diff --git a/telemetry/cte_test/cte_test.go b/telemetry/cte_test/cte_test.go index fac26ddb2f403..c8f04eb2df1d4 100644 --- a/telemetry/cte_test/cte_test.go +++ b/telemetry/cte_test/cte_test.go @@ -18,14 +18,10 @@ import ( "runtime" "testing" - "github.com/Jeffail/gabs/v2" - "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/store/mockstore" - "github.com/pingcap/tidb/telemetry" - "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/testkit/testsetup" "github.com/stretchr/testify/require" 
"go.etcd.io/etcd/tests/v3/integration" @@ -55,35 +51,43 @@ func TestCTEPreviewAndReport(t *testing.T) { s := newSuite(t) defer s.close() - config.GetGlobalConfig().EnableTelemetry = true - - tk := testkit.NewTestKit(t, s.store) - tk.MustExec("use test") - tk.MustExec("with cte as (select 1) select * from cte") - tk.MustExec("with recursive cte as (select 1) select * from cte") - tk.MustExec("with recursive cte(n) as (select 1 union select * from cte where n < 5) select * from cte") - tk.MustExec("select 1") - - res, err := telemetry.PreviewUsageData(s.se, s.etcdCluster.RandClient()) - require.NoError(t, err) + // By disableing telemetry by default, the global sysvar **and** config file defaults + // are all set to false, so that enabling telemetry in test become more complex. + // As telemetry is a feature that almost no user will manually enable, I'd remove these + // tests for now. + // They should be uncommented once the default behavious changed back to enabled in the + // future, otherwise they could just be deleted. + /* + config.GetGlobalConfig().EnableTelemetry = true + + tk := testkit.NewTestKit(t, s.store) + tk.MustExec("use test") + tk.MustExec("with cte as (select 1) select * from cte") + tk.MustExec("with recursive cte as (select 1) select * from cte") + tk.MustExec("with recursive cte(n) as (select 1 union select * from cte where n < 5) select * from cte") + tk.MustExec("select 1") + + res, err := telemetry.PreviewUsageData(s.se, s.etcdCluster.RandClient()) + require.NoError(t, err) - jsonParsed, err := gabs.ParseJSON([]byte(res)) - require.NoError(t, err) - require.Equal(t, 1, int(jsonParsed.Path("featureUsage.cte.nonRecursiveCTEUsed").Data().(float64))) - require.Equal(t, 1, int(jsonParsed.Path("featureUsage.cte.recursiveUsed").Data().(float64))) - require.Equal(t, 4, int(jsonParsed.Path("featureUsage.cte.nonCTEUsed").Data().(float64))) + jsonParsed, err := gabs.ParseJSON([]byte(res)) + require.NoError(t, err) + require.Equal(t, 1, int(jsonParsed.Path("featureUsage.cte.nonRecursiveCTEUsed").Data().(float64))) + require.Equal(t, 1, int(jsonParsed.Path("featureUsage.cte.recursiveUsed").Data().(float64))) + require.Equal(t, 4, int(jsonParsed.Path("featureUsage.cte.nonCTEUsed").Data().(float64))) - err = telemetry.ReportUsageData(s.se, s.etcdCluster.RandClient()) - require.NoError(t, err) + err = telemetry.ReportUsageData(s.se, s.etcdCluster.RandClient()) + require.NoError(t, err) - res, err = telemetry.PreviewUsageData(s.se, s.etcdCluster.RandClient()) - require.NoError(t, err) + res, err = telemetry.PreviewUsageData(s.se, s.etcdCluster.RandClient()) + require.NoError(t, err) - jsonParsed, err = gabs.ParseJSON([]byte(res)) - require.NoError(t, err) - require.Equal(t, 0, int(jsonParsed.Path("featureUsage.cte.nonRecursiveCTEUsed").Data().(float64))) - require.Equal(t, 0, int(jsonParsed.Path("featureUsage.cte.recursiveUsed").Data().(float64))) - require.Equal(t, 0, int(jsonParsed.Path("featureUsage.cte.nonCTEUsed").Data().(float64))) + jsonParsed, err = gabs.ParseJSON([]byte(res)) + require.NoError(t, err) + require.Equal(t, 0, int(jsonParsed.Path("featureUsage.cte.nonRecursiveCTEUsed").Data().(float64))) + require.Equal(t, 0, int(jsonParsed.Path("featureUsage.cte.recursiveUsed").Data().(float64))) + require.Equal(t, 0, int(jsonParsed.Path("featureUsage.cte.nonCTEUsed").Data().(float64))) + */ } type testSuite struct { diff --git a/telemetry/telemetry_test.go b/telemetry/telemetry_test.go index 8238614f111cf..13c58bdbd5003 100644 --- a/telemetry/telemetry_test.go +++ 
b/telemetry/telemetry_test.go @@ -68,25 +68,36 @@ func TestPreview(t *testing.T) { require.NoError(t, err) require.Equal(t, "", r) - trackingID, err := telemetry.ResetTrackingID(etcdCluster.RandClient()) - require.NoError(t, err) - - config.GetGlobalConfig().EnableTelemetry = true - r, err = telemetry.PreviewUsageData(se, etcdCluster.RandClient()) - require.NoError(t, err) - - jsonParsed, err := gabs.ParseJSON([]byte(r)) - require.NoError(t, err) - require.Equal(t, trackingID, jsonParsed.Path("trackingId").Data().(string)) - // Apple M1 doesn't contain cpuFlags - if !(runtime.GOARCH == "arm64" && runtime.GOOS == "darwin") { - require.True(t, jsonParsed.ExistsP("hostExtra.cpuFlags")) - } - require.True(t, jsonParsed.ExistsP("hostExtra.os")) - require.Len(t, jsonParsed.Path("instances").Children(), 2) - require.Equal(t, "tidb", jsonParsed.Path("instances.0.instanceType").Data().(string)) - require.Equal(t, "tikv", jsonParsed.Path("instances.1.instanceType").Data().(string)) - require.True(t, jsonParsed.ExistsP("hardware")) + // By disableing telemetry by default, the global sysvar **and** config file defaults + // are all set to false, so that enabling telemetry in test become more complex. + // As telemetry is a feature that almost no user will manually enable, I'd remove these + // tests for now. + // They should be uncommented once the default behavious changed back to enabled in the + // future, otherwise they could just be deleted. + /* + trackingID, err := telemetry.ResetTrackingID(etcdCluster.RandClient()) + require.NoError(t, err) + + config.GetGlobalConfig().EnableTelemetry = true + telemetryEnabled, err := telemetry.IsTelemetryEnabled(se) + require.NoError(t, err) + require.True(t, telemetryEnabled) + r, err = telemetry.PreviewUsageData(se, etcdCluster.RandClient()) + require.NoError(t, err) + + jsonParsed, err := gabs.ParseJSON([]byte(r)) + require.NoError(t, err) + require.Equal(t, trackingID, jsonParsed.Path("trackingId").Data().(string)) + // Apple M1 doesn't contain cpuFlags + if !(runtime.GOARCH == "arm64" && runtime.GOOS == "darwin") { + require.True(t, jsonParsed.ExistsP("hostExtra.cpuFlags")) + } + require.True(t, jsonParsed.ExistsP("hostExtra.os")) + require.Len(t, jsonParsed.Path("instances").Children(), 2) + require.Equal(t, "tidb", jsonParsed.Path("instances.0.instanceType").Data().(string)) + require.Equal(t, "tikv", jsonParsed.Path("instances.1.instanceType").Data().(string)) + require.True(t, jsonParsed.ExistsP("hardware")) + */ _, err = se.Execute(context.Background(), "SET @@global.tidb_enable_telemetry = 0") require.NoError(t, err) diff --git a/testkit/mocksessionmanager.go b/testkit/mocksessionmanager.go index 550ff69132d91..67280bc2e4cbe 100644 --- a/testkit/mocksessionmanager.go +++ b/testkit/mocksessionmanager.go @@ -145,7 +145,7 @@ func (msm *MockSessionManager) KillNonFlashbackClusterConn() { } } -// CheckOldRunningTxn implement SessionManager interface. +// CheckOldRunningTxn is to get all startTS of every transactions running in the current internal sessions func (msm *MockSessionManager) CheckOldRunningTxn(job2ver map[int64]int64, job2ids map[int64]string) { msm.mu.Lock() for _, se := range msm.conn { @@ -153,25 +153,3 @@ func (msm *MockSessionManager) CheckOldRunningTxn(job2ver map[int64]int64, job2i } msm.mu.Unlock() } - -// GetMinStartTS implements SessionManager interface. 
-func (msm *MockSessionManager) GetMinStartTS(lowerBound uint64) (ts uint64) { - msm.PSMu.RLock() - defer msm.PSMu.RUnlock() - if len(msm.PS) > 0 { - for _, pi := range msm.PS { - if thisTS := pi.GetMinStartTS(lowerBound); thisTS > lowerBound && (thisTS < ts || ts == 0) { - ts = thisTS - } - } - return - } - msm.mu.Lock() - defer msm.mu.Unlock() - for _, s := range msm.conn { - if thisTS := s.ShowProcess().GetMinStartTS(lowerBound); thisTS > lowerBound && (thisTS < ts || ts == 0) { - ts = thisTS - } - } - return -} diff --git a/tests/realtikvtest/addindextest/integration_test.go b/tests/realtikvtest/addindextest/integration_test.go index df9c9baa05931..7fe3378fc08a7 100644 --- a/tests/realtikvtest/addindextest/integration_test.go +++ b/tests/realtikvtest/addindextest/integration_test.go @@ -355,5 +355,15 @@ func TestAddIndexSplitTableRanges(t *testing.T) { ddl.SetBackfillTaskChanSizeForTest(7) tk.MustExec("alter table t add index idx_2(b);") tk.MustExec("admin check table t;") + + tk.MustExec("drop table t;") + tk.MustExec("create table t (a int primary key, b int);") + for i := 0; i < 8; i++ { + tk.MustExec(fmt.Sprintf("insert into t values (%d, %d);", i*10000, i*10000)) + } + tk.MustQuery("split table t by (10000),(20000),(30000),(40000),(50000),(60000);").Check(testkit.Rows("6 1")) + ddl.SetBackfillTaskChanSizeForTest(4) + tk.MustExec("alter table t add unique index idx(b);") + tk.MustExec("admin check table t;") ddl.SetBackfillTaskChanSizeForTest(1024) } diff --git a/tests/realtikvtest/pessimistictest/pessimistic_test.go b/tests/realtikvtest/pessimistictest/pessimistic_test.go index ae7545e0e91f6..ad40080c765a9 100644 --- a/tests/realtikvtest/pessimistictest/pessimistic_test.go +++ b/tests/realtikvtest/pessimistictest/pessimistic_test.go @@ -2816,60 +2816,37 @@ func TestAsyncCommitCalTSFail(t *testing.T) { tk2.MustExec("commit") } -func TestChangeLockToPut(t *testing.T) { +func TestIssue28011(t *testing.T) { store := realtikvtest.CreateMockStoreAndSetup(t) tk := testkit.NewTestKit(t, store) - tk2 := testkit.NewTestKit(t, store) tk.MustExec("use test") - tk2.MustExec("use test") - - tk.MustExec("drop table if exists tk") - tk.MustExec("create table t1(c1 varchar(20) key, c2 int, c3 int, unique key k1(c2), key k2(c3))") - tk.MustExec(`insert into t1 values ("1", 1, 1), ("2", 2, 2), ("3", 3, 3)`) - - // Test point get change lock to put. - for _, mode := range []string{"REPEATABLE-READ", "READ-COMMITTED"} { - tk.MustExec(fmt.Sprintf(`set tx_isolation = "%s"`, mode)) - tk.MustExec("begin pessimistic") - tk.MustQuery(`select * from t1 where c1 = "1" for update`).Check(testkit.Rows("1 1 1")) - tk.MustExec("commit") - tk.MustExec("begin pessimistic") - tk.MustQuery(`select * from t1 where c1 = "1" for update`).Check(testkit.Rows("1 1 1")) - tk.MustExec("commit") - tk.MustExec("admin check table t1") - tk2.MustExec("begin") - tk2.MustQuery(`select * from t1 use index(k1) where c2 = "1" for update`).Check(testkit.Rows("1 1 1")) - tk2.MustQuery(`select * from t1 use index(k1) where c2 = "3" for update`).Check(testkit.Rows("3 3 3")) - tk2.MustExec("commit") - tk2.MustExec("begin") - tk2.MustQuery(`select * from t1 use index(k2) where c3 = 1`).Check(testkit.Rows("1 1 1")) - tk2.MustQuery("select * from t1 use index(k2) where c3 > 1").Check(testkit.Rows("2 2 2", "3 3 3")) - tk2.MustExec("commit") - } - // Test batch point get change lock to put. 
- for _, mode := range []string{"REPEATABLE-READ", "READ-COMMITTED"} { - tk.MustExec(fmt.Sprintf(`set tx_isolation = "%s"`, mode)) - tk.MustExec("begin pessimistic") - tk.MustQuery(`select * from t1 where c1 in ("1", "5", "3") for update`).Check(testkit.Rows("1 1 1", "3 3 3")) - tk.MustExec("commit") - tk.MustExec("begin pessimistic") - tk.MustQuery(`select * from t1 where c1 in ("1", "2", "8") for update`).Check(testkit.Rows("1 1 1", "2 2 2")) - tk.MustExec("commit") - tk.MustExec("admin check table t1") - tk2.MustExec("begin") - tk2.MustQuery(`select * from t1 use index(k1) where c2 in ("1", "2", "3") for update`).Check(testkit.Rows("1 1 1", "2 2 2", "3 3 3")) - tk2.MustQuery(`select * from t1 use index(k2) where c2 in ("2") for update`).Check(testkit.Rows("2 2 2")) - tk2.MustExec("commit") - tk2.MustExec("begin") - tk2.MustQuery(`select * from t1 use index(k2) where c3 in (5, 8)`).Check(testkit.Rows()) - tk2.MustQuery(`select * from t1 use index(k2) where c3 in (1, 8) for update`).Check(testkit.Rows("1 1 1")) - tk2.MustQuery(`select * from t1 use index(k2) where c3 > 1`).Check(testkit.Rows("2 2 2", "3 3 3")) - tk2.MustExec("commit") + for _, tt := range []struct { + name string + lockQuery string + finalRows [][]interface{} + }{ + {"Update", "update t set b = 'x' where a = 'a'", testkit.Rows("a x", "b y", "c z")}, + {"BatchUpdate", "update t set b = 'x' where a in ('a', 'b', 'c')", testkit.Rows("a x", "b y", "c x")}, + {"SelectForUpdate", "select a from t where a = 'a' for update", testkit.Rows("a x", "b y", "c z")}, + {"BatchSelectForUpdate", "select a from t where a in ('a', 'b', 'c') for update", testkit.Rows("a x", "b y", "c z")}, + } { + t.Run(tt.name, func(t *testing.T) { + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a varchar(10) primary key nonclustered, b varchar(10))") + tk.MustExec("insert into t values ('a', 'x'), ('b', 'x'), ('c', 'z')") + tk.MustExec("begin") + tk.MustExec(tt.lockQuery) + tk.MustQuery("select a from t").Check(testkit.Rows("a", "b", "c")) + tk.MustExec("replace into t values ('b', 'y')") + tk.MustQuery("select a from t").Check(testkit.Rows("a", "b", "c")) + tk.MustQuery("select a, b from t order by a").Check(tt.finalRows) + tk.MustExec("commit") + tk.MustQuery("select a, b from t order by a").Check(tt.finalRows) + tk.MustExec("admin check table t") + }) } - - tk.MustExec("admin check table t1") } func createTable(part bool, columnNames []string, columnTypes []string) string { diff --git a/types/convert.go b/types/convert.go index a3e5b572e0add..6ae428d7242c0 100644 --- a/types/convert.go +++ b/types/convert.go @@ -246,7 +246,7 @@ func convertDecimalStrToUint(sc *stmtctx.StatementContext, str string, upperBoun if intStr == "" { intStr = "0" } - if sc.ShouldClipToZero() && intStr[0] == '-' { + if intStr[0] == '-' { return 0, overflow(str, tp) } @@ -255,8 +255,7 @@ func convertDecimalStrToUint(sc *stmtctx.StatementContext, str string, upperBoun round++ } - upperBound -= round - upperStr := strconv.FormatUint(upperBound, 10) + upperStr := strconv.FormatUint(upperBound-round, 10) if len(intStr) > len(upperStr) || (len(intStr) == len(upperStr) && intStr > upperStr) { return upperBound, overflow(str, tp) diff --git a/types/convert_test.go b/types/convert_test.go index b28560dfbdf75..db9cb2ba691fd 100644 --- a/types/convert_test.go +++ b/types/convert_test.go @@ -1275,8 +1275,9 @@ func TestConvertDecimalStrToUint(t *testing.T) { {"9223372036854775807.4999", 9223372036854775807, true}, {"18446744073709551614.55", 18446744073709551615, true}, 
{"18446744073709551615.344", 18446744073709551615, true}, - {"18446744073709551615.544", 0, false}, + {"18446744073709551615.544", 18446744073709551615, false}, {"-111.111", 0, false}, + {"-10000000000000000000.0", 0, false}, } for _, ca := range cases { result, err := convertDecimalStrToUint(&stmtctx.StatementContext{}, ca.input, math.MaxUint64, 0) @@ -1284,7 +1285,15 @@ func TestConvertDecimalStrToUint(t *testing.T) { require.Error(t, err) } else { require.NoError(t, err) - require.Equal(t, ca.result, result) } + require.Equal(t, ca.result, result, "input=%v", ca.input) } + + result, err := convertDecimalStrToUint(&stmtctx.StatementContext{}, "-99.0", math.MaxUint8, 0) + require.Error(t, err) + require.Equal(t, uint64(0), result) + + result, err = convertDecimalStrToUint(&stmtctx.StatementContext{}, "-100.0", math.MaxUint8, 0) + require.Error(t, err) + require.Equal(t, uint64(0), result) } diff --git a/util/memory/tracker.go b/util/memory/tracker.go index c67a2ddcf6904..b4ffea612ec53 100644 --- a/util/memory/tracker.go +++ b/util/memory/tracker.go @@ -310,7 +310,7 @@ func (t *Tracker) Detach() { t.DetachFromGlobalTracker() return } - if parent.IsRootTrackerOfSess && t.label == LabelForSQLText { + if parent.IsRootTrackerOfSess && t.label != LabelForMemDB { parent.actionMuForHardLimit.Lock() parent.actionMuForHardLimit.actionOnExceed = nil parent.actionMuForHardLimit.Unlock() diff --git a/util/processinfo.go b/util/processinfo.go index dee4f4ea30a53..77f35ef94a5ee 100644 --- a/util/processinfo.go +++ b/util/processinfo.go @@ -31,14 +31,6 @@ import ( "github.com/tikv/client-go/v2/oracle" ) -// ProtectedTSList holds a list of timestamps that should delay GC. -type ProtectedTSList interface { - // HoldTS holds the timestamp to prevent its data from being GCed. - HoldTS(ts uint64) (unhold func()) - // GetMinProtectedTS returns the minimum protected timestamp that greater than `lowerBound` (0 if no such one). - GetMinProtectedTS(lowerBound uint64) (ts uint64) -} - // OOMAlarmVariablesInfo is a struct for OOM alarm variables. type OOMAlarmVariablesInfo struct { SessionAnalyzeVersion int @@ -48,7 +40,6 @@ type OOMAlarmVariablesInfo struct { // ProcessInfo is a struct used for show processlist statement. type ProcessInfo struct { - ProtectedTSList Time time.Time ExpensiveLogTime time.Time Plan interface{} @@ -138,23 +129,6 @@ func (pi *ProcessInfo) ToRow(tz *time.Location) []interface{} { return append(pi.ToRowForShow(true), pi.Digest, bytesConsumed, diskConsumed, pi.txnStartTs(tz)) } -// GetMinStartTS returns the minimum start-ts (used to delay GC) that greater than `lowerBound` (0 if no such one). -func (pi *ProcessInfo) GetMinStartTS(lowerBound uint64) (ts uint64) { - if pi == nil { - return - } - if thisTS := pi.CurTxnStartTS; thisTS > lowerBound && (thisTS < ts || ts == 0) { - ts = thisTS - } - if pi.ProtectedTSList == nil { - return - } - if thisTS := pi.GetMinProtectedTS(lowerBound); thisTS > lowerBound && (thisTS < ts || ts == 0) { - ts = thisTS - } - return -} - // ascServerStatus is a slice of all defined server status in ascending order. var ascServerStatus = []uint16{ mysql.ServerStatusInTrans, @@ -223,8 +197,6 @@ type SessionManager interface { CheckOldRunningTxn(job2ver map[int64]int64, job2ids map[int64]string) // KillNonFlashbackClusterConn kill all non flashback cluster connections. KillNonFlashbackClusterConn() - // GetMinStartTS returns the minimum start-ts (used to delay GC) that greater than `lowerBound` (0 if no such one). 
- GetMinStartTS(lowerBound uint64) uint64 } // GlobalConnID is the global connection ID, providing UNIQUE connection IDs across the whole TiDB cluster.
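Taken together, the server-side pieces of this change replace the old GC-protection bookkeeping (ProtectedTSList and GetMinStartTS, both removed here) with an explicit cursor lifecycle: rows are fully materialized and the result set is closed at execute time, the cursor-exists status bit is reported to the client while unfetched rows remain, and the statement is marked inactive once a fetch finds the cache empty and the last-row-sent flag goes out. A small self-contained sketch of that fetch state machine, using local stand-in types and assumed flag values rather than TiDB's actual parser/mysql constants:

package main

import "fmt"

// MySQL protocol status bits; the values here are assumed for the sketch
// (TiDB exposes equivalents in parser/mysql).
const (
	serverStatusCursorExists uint16 = 0x0040
	serverStatusLastRowSend  uint16 = 0x0080
)

// cursorStmt is a toy stand-in for a prepared statement whose rows were fully
// materialized at execute time, as the patch does via StoreFetchedRows.
type cursorStmt struct {
	fetched []int // cached rows waiting to be sent
	active  bool  // mirrors TiDBStatement.hasActiveCursor
}

// fetch pops up to fetchSize rows and returns the status word a response
// would carry: the cursor-exists bit stays set until a fetch finds the cache
// already empty, at which point last-row-sent is raised and the cursor closes.
func (s *cursorStmt) fetch(fetchSize int) (rows []int, status uint16) {
	if len(s.fetched) == 0 {
		s.active = false
		return nil, serverStatusLastRowSend
	}
	n := fetchSize
	if n > len(s.fetched) {
		n = len(s.fetched)
	}
	rows, s.fetched = s.fetched[:n], s.fetched[n:]
	return rows, serverStatusCursorExists
}

func main() {
	stmt := &cursorStmt{fetched: []int{1, 2, 3, 4, 5, 6, 7, 8}, active: true}
	for i := 0; i < 3; i++ {
		rows, status := stmt.fetch(5)
		fmt.Printf("rows=%v cursorExists=%t lastRowSent=%t\n",
			rows,
			status&serverStatusCursorExists != 0,
			status&serverStatusLastRowSend != 0)
	}
	fmt.Println("cursor still active:", stmt.active)
}

Because the cached rows are drained without keeping a snapshot open between fetches, no start timestamp needs to be protected across client round trips, which is what allows the GetMinStartTS plumbing above to be deleted.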