From aaf15a700df46e2a4925be884687b11c297997d8 Mon Sep 17 00:00:00 2001
From: hackersean
Date: Thu, 22 Dec 2022 15:42:55 +0800
Subject: [PATCH] This is an automated cherry-pick of #39932

Signed-off-by: ti-chi-bot
---
 kv/mpp.go | 4 +
 metrics/metrics.go | 55 +
 metrics/server.go | 12 +
 planner/core/fragment.go | 2 +-
 sessionctx/sessionstates/session_states.go | 85 +
 .../sessionstates/session_states_test.go | 1468 +++++++++++++++++
 sessionctx/variable/session.go | 184 +++
 store/copr/BUILD.bazel | 91 +
 store/copr/batch_coprocessor.go | 164 +-
 store/copr/batch_coprocessor_test.go | 4 +-
 store/copr/mpp.go | 22 +
 store/copr/mpp_probe.go | 270 +++
 store/copr/mpp_probe_test.go | 177 ++
 store/copr/store.go | 1 +
 .../sessiontest/session_fail_test.go | 205 +++
 tidb-server/BUILD.bazel | 107 ++
 tidb-server/main.go | 8 +
 17 files changed, 2853 insertions(+), 6 deletions(-)
 create mode 100644 sessionctx/sessionstates/session_states.go
 create mode 100644 sessionctx/sessionstates/session_states_test.go
 create mode 100644 store/copr/BUILD.bazel
 create mode 100644 store/copr/mpp_probe.go
 create mode 100644 store/copr/mpp_probe_test.go
 create mode 100644 tests/realtikvtest/sessiontest/session_fail_test.go
 create mode 100644 tidb-server/BUILD.bazel

diff --git a/kv/mpp.go b/kv/mpp.go
index 231f0cccb2325..48383234d5934 100644
--- a/kv/mpp.go
+++ b/kv/mpp.go
@@ -78,7 +78,11 @@ type MPPDispatchRequest struct {
 type MPPClient interface {
 	// ConstructMPPTasks schedules tasks for a plan fragment.
 	// TODO:: This interface will be refined after we support more executors.
+<<<<<<< HEAD
 	ConstructMPPTasks(context.Context, *MPPBuildTasksRequest, map[string]time.Time, time.Duration) ([]MPPTaskMeta, error)
+=======
+	ConstructMPPTasks(context.Context, *MPPBuildTasksRequest, time.Duration) ([]MPPTaskMeta, error)
+>>>>>>> aeccf77637 (*: optimize mpp probe (#39932))

 	// DispatchMPPTasks dispatches ALL mpp requests at once, and returns an iterator that transfers the data.
 	DispatchMPPTasks(ctx context.Context, vars interface{}, reqs []*MPPDispatchRequest, needTriggerFallback bool, startTs uint64) Response
diff --git a/metrics/metrics.go b/metrics/metrics.go
index b24eb7b152750..3f7b434302633 100644
--- a/metrics/metrics.go
+++ b/metrics/metrics.go
@@ -154,6 +154,7 @@ func RegisterMetrics() {
 	prometheus.MustRegister(TokenGauge)
 	prometheus.MustRegister(ConfigStatus)
 	prometheus.MustRegister(TiFlashQueryTotalCounter)
+	prometheus.MustRegister(TiFlashFailedMPPStoreState)
 	prometheus.MustRegister(SmallTxnWriteDuration)
 	prometheus.MustRegister(TxnWriteThroughput)
 	prometheus.MustRegister(LoadSysVarCacheCounter)
@@ -167,3 +168,57 @@ func RegisterMetrics() {
 	tikvmetrics.RegisterMetrics()
 	tikvmetrics.TiKVPanicCounter = PanicCounter // reset tidb metrics for tikv metrics
 }
+<<<<<<< HEAD
+=======
+
+var mode struct {
+	sync.Mutex
+	isSimplified bool
+}
+
+// ToggleSimplifiedMode is used to register/unregister the metrics that are unused by Grafana.
+func ToggleSimplifiedMode(simplified bool) {
+	var unusedMetricsByGrafana = []prometheus.Collector{
+		StatementDeadlockDetectDuration,
+		ValidateReadTSFromPDCount,
+		LoadTableCacheDurationHistogram,
+		TxnWriteThroughput,
+		SmallTxnWriteDuration,
+		InfoCacheCounters,
+		ReadFromTableCacheCounter,
+		TiFlashQueryTotalCounter,
+		TiFlashFailedMPPStoreState,
+		CampaignOwnerCounter,
+		NonTransactionalDMLCount,
+		MemoryUsage,
+		TokenGauge,
+		tikvmetrics.TiKVRawkvSizeHistogram,
+		tikvmetrics.TiKVRawkvCmdHistogram,
+		tikvmetrics.TiKVReadThroughput,
+		tikvmetrics.TiKVSmallReadDuration,
+		tikvmetrics.TiKVBatchWaitOverLoad,
+		tikvmetrics.TiKVBatchClientRecycle,
+		tikvmetrics.TiKVRequestRetryTimesHistogram,
+		tikvmetrics.TiKVStatusDuration,
+	}
+	mode.Lock()
+	defer mode.Unlock()
+	if mode.isSimplified == simplified {
+		return
+	}
+	mode.isSimplified = simplified
+	if simplified {
+		for _, m := range unusedMetricsByGrafana {
+			prometheus.Unregister(m)
+		}
+	} else {
+		for _, m := range unusedMetricsByGrafana {
+			err := prometheus.Register(m)
+			if err != nil {
+				logutil.BgLogger().Error("cannot register metrics", zap.Error(err))
+				break
+			}
+		}
+	}
+}
+>>>>>>> aeccf77637 (*: optimize mpp probe (#39932))
diff --git a/metrics/server.go b/metrics/server.go
index 440a833e6f03e..0f4058c366286 100644
--- a/metrics/server.go
+++ b/metrics/server.go
@@ -230,7 +230,19 @@ var (
 			Help: "Counter of TiFlash queries.",
 		}, []string{LblType, LblResult})

+<<<<<<< HEAD
 	PDApiExecutionHistogram = prometheus.NewHistogramVec(
+=======
+	TiFlashFailedMPPStoreState = prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Namespace: "tidb",
+			Subsystem: "server",
+			Name: "tiflash_failed_store",
+			Help: "Status of failed TiFlash MPP store: -1 means detector heartbeat, 0 means reachable, 1 means abnormal.",
+		}, []string{LblAddress})
+
+	PDAPIExecutionHistogram = prometheus.NewHistogramVec(
+>>>>>>> aeccf77637 (*: optimize mpp probe (#39932))
 		prometheus.HistogramOpts{
 			Namespace: "tidb",
 			Subsystem: "server",
diff --git a/planner/core/fragment.go b/planner/core/fragment.go
index 383e745d47d34..ae3de04277444 100644
--- a/planner/core/fragment.go
+++ b/planner/core/fragment.go
@@ -368,7 +368,7 @@ func (e *mppTaskGenerator) constructMPPTasksForSinglePartitionTable(ctx context.
 		logutil.BgLogger().Warn("MPP store fail ttl is invalid", zap.Error(err))
 		ttl = 30 * time.Second
 	}
-	metas, err := e.ctx.GetMPPClient().ConstructMPPTasks(ctx, req, e.ctx.GetSessionVars().MPPStoreLastFailTime, ttl)
+	metas, err := e.ctx.GetMPPClient().ConstructMPPTasks(ctx, req, ttl)
 	if err != nil {
 		return nil, errors.Trace(err)
 	}
diff --git a/sessionctx/sessionstates/session_states.go b/sessionctx/sessionstates/session_states.go
new file mode 100644
index 0000000000000..c9e1652a9c1df
--- /dev/null
+++ b/sessionctx/sessionstates/session_states.go
@@ -0,0 +1,85 @@
+// Copyright 2022 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
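+
+// The migration flow this file supports is driven from SQL: `show
+// session_states` returns a single row holding the JSON-encoded SessionStates,
+// and `set session_states '<json>'` restores them on another connection. A
+// minimal sketch of that round trip, using the testkit helpers that the tests
+// below rely on (tk1 and tk2 are two separate sessions):
+//
+//	rows := tk1.MustQuery("show session_states").Rows()
+//	// strconv.Quote wraps the JSON in double quotes and escapes any embedded
+//	// quotes and backslashes (e.g. from prepared-statement texts).
+//	state := strconv.Quote(rows[0][0].(string))
+//	tk2.MustExec(fmt.Sprintf("set session_states %s", state))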
+
+package sessionstates
+
+import (
+	"github.com/pingcap/tidb/errno"
+	ptypes "github.com/pingcap/tidb/parser/types"
+	"github.com/pingcap/tidb/sessionctx/stmtctx"
+	"github.com/pingcap/tidb/types"
+	"github.com/pingcap/tidb/util/dbterror"
+)
+
+// SessionStateType is the type of session states.
+type SessionStateType int
+
+var (
+	// ErrCannotMigrateSession indicates the session cannot be migrated.
+	ErrCannotMigrateSession = dbterror.ClassSession.NewStd(errno.ErrCannotMigrateSession)
+)
+
+// These enums represent the types of session state handlers.
+const (
+	// StatePrepareStmt represents prepared statements.
+	StatePrepareStmt SessionStateType = iota
+	// StateBinding represents session SQL bindings.
+	StateBinding
+)
+
+// PreparedStmtInfo contains the information about prepared statements, both text and binary protocols.
+type PreparedStmtInfo struct {
+	Name string `json:"name,omitempty"`
+	StmtText string `json:"text"`
+	StmtDB string `json:"db,omitempty"`
+	ParamTypes []byte `json:"types,omitempty"`
+}
+
+// QueryInfo represents the information of the last executed query. It's used to expose information for test purposes.
+type QueryInfo struct {
+	TxnScope string `json:"txn_scope"`
+	StartTS uint64 `json:"start_ts"`
+	ForUpdateTS uint64 `json:"for_update_ts"`
+	ErrMsg string `json:"error,omitempty"`
+}
+
+// LastDDLInfo represents the information of the last DDL. It's used to expose information for test purposes.
+type LastDDLInfo struct {
+	Query string `json:"query"`
+	SeqNum uint64 `json:"seq_num"`
+}
+
+// SessionStates contains all the states in the session that should be migrated when the session
+// is migrated to another server. It is shown by `show session_states` and recovered by `set session_states`.
+type SessionStates struct {
+	UserVars map[string]*types.Datum `json:"user-var-values,omitempty"`
+	UserVarTypes map[string]*ptypes.FieldType `json:"user-var-types,omitempty"`
+	SystemVars map[string]string `json:"sys-vars,omitempty"`
+	PreparedStmts map[uint32]*PreparedStmtInfo `json:"prepared-stmts,omitempty"`
+	PreparedStmtID uint32 `json:"prepared-stmt-id,omitempty"`
+	Status uint16 `json:"status,omitempty"`
+	CurrentDB string `json:"current-db,omitempty"`
+	LastTxnInfo string `json:"txn-info,omitempty"`
+	LastQueryInfo *QueryInfo `json:"query-info,omitempty"`
+	LastDDLInfo *LastDDLInfo `json:"ddl-info,omitempty"`
+	LastFoundRows uint64 `json:"found-rows,omitempty"`
+	FoundInPlanCache bool `json:"in-plan-cache,omitempty"`
+	FoundInBinding bool `json:"in-binding,omitempty"`
+	SequenceLatestValues map[int64]int64 `json:"seq-values,omitempty"`
+	LastAffectedRows int64 `json:"affected-rows,omitempty"`
+	LastInsertID uint64 `json:"last-insert-id,omitempty"`
+	Warnings []stmtctx.SQLWarn `json:"warnings,omitempty"`
+	// Define it as string to avoid an import cycle.
+	Bindings string `json:"bindings,omitempty"`
+}
diff --git a/sessionctx/sessionstates/session_states_test.go b/sessionctx/sessionstates/session_states_test.go
new file mode 100644
index 0000000000000..21de8d53727d6
--- /dev/null
+++ b/sessionctx/sessionstates/session_states_test.go
@@ -0,0 +1,1468 @@
+// Copyright 2022 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sessionstates_test + +import ( + "context" + "encoding/binary" + "fmt" + "strconv" + "strings" + "testing" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/config" + "github.com/pingcap/tidb/errno" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/parser/terror" + "github.com/pingcap/tidb/server" + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/testkit" + "github.com/pingcap/tidb/util/sem" + "github.com/stretchr/testify/require" +) + +func TestGrammar(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + rows := tk.MustQuery("show session_states").Rows() + require.Len(t, rows, 1) + tk.MustExec("set session_states '{}'") + tk.MustGetErrCode("set session_states 1", errno.ErrParse) +} + +func TestUserVars(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("create table test.t1(" + + "j json, b blob, s varchar(255), st set('red', 'green', 'blue'), en enum('red', 'green', 'blue'))") + tk.MustExec("insert into test.t1 values('{\"color:\": \"red\"}', 'red', 'red', 'red,green', 'red')") + + tests := []string{ + "", + "set @%s=null", + "set @%s=1", + "set @%s=1.0e10", + "set @%s=1.0-1", + "set @%s=now()", + "set @%s=1, @%s=1.0-1", + "select @%s:=1+1", + // TiDB doesn't support following features. + //"select j into @%s from test.t1", + //"select j,b,s,st,en into @%s,@%s,@%s,@%s,@%s from test.t1", + } + + for _, tt := range tests { + tk1 := testkit.NewTestKit(t, store) + tk2 := testkit.NewTestKit(t, store) + namesNum := strings.Count(tt, "%s") + names := make([]any, 0, namesNum) + for i := 0; i < namesNum; i++ { + names = append(names, fmt.Sprintf("a%d", i)) + } + var sql string + if len(tt) > 0 { + sql = fmt.Sprintf(tt, names...) 
+ tk1.MustExec(sql) + } + showSessionStatesAndSet(t, tk1, tk2) + for _, name := range names { + sql := fmt.Sprintf("select @%s", name) + msg := fmt.Sprintf("sql: %s, var name: %s", sql, name) + value1 := tk1.MustQuery(sql).Rows()[0][0] + value2 := tk2.MustQuery(sql).Rows()[0][0] + require.Equal(t, value1, value2, msg) + } + } +} + +func TestSystemVars(t *testing.T) { + store := testkit.CreateMockStore(t) + + tests := []struct { + stmts []string + varName string + inSessionStates bool + checkStmt string + expectedValue string + }{ + { + // normal variable + inSessionStates: true, + varName: variable.TiDBMaxTiFlashThreads, + expectedValue: strconv.Itoa(variable.DefTiFlashMaxThreads), + }, + { + // hidden variable + inSessionStates: false, + varName: variable.TiDBTxnReadTS, + }, + { + // none-scoped variable + inSessionStates: false, + varName: variable.DataDir, + expectedValue: "/usr/local/mysql/data/", + }, + { + // instance-scoped variable + inSessionStates: false, + varName: variable.TiDBGeneralLog, + expectedValue: "0", + }, + { + // global-scoped variable + inSessionStates: false, + varName: variable.TiDBAutoAnalyzeStartTime, + expectedValue: variable.DefAutoAnalyzeStartTime, + }, + { + // sem invisible variable + inSessionStates: false, + varName: variable.TiDBConfig, + }, + { + // noop variables + stmts: []string{"set sql_buffer_result=true"}, + inSessionStates: true, + varName: "sql_buffer_result", + expectedValue: "1", + }, + { + stmts: []string{"set transaction isolation level repeatable read"}, + inSessionStates: true, + varName: "tx_isolation_one_shot", + expectedValue: "REPEATABLE-READ", + }, + { + inSessionStates: false, + varName: variable.Timestamp, + }, + { + stmts: []string{"set timestamp=100"}, + inSessionStates: true, + varName: variable.Timestamp, + expectedValue: "100", + }, + { + stmts: []string{"set rand_seed1=10000000, rand_seed2=1000000"}, + inSessionStates: true, + varName: variable.RandSeed1, + checkStmt: "select rand()", + expectedValue: "0.028870999839968048", + }, + { + stmts: []string{"set rand_seed1=10000000, rand_seed2=1000000", "select rand()"}, + inSessionStates: true, + varName: variable.RandSeed1, + checkStmt: "select rand()", + expectedValue: "0.11641535266900002", + }, + } + + if !sem.IsEnabled() { + sem.Enable() + defer sem.Disable() + } + for _, tt := range tests { + tk1 := testkit.NewTestKit(t, store) + for _, stmt := range tt.stmts { + if strings.HasPrefix(stmt, "select") { + tk1.MustQuery(stmt) + } else { + tk1.MustExec(stmt) + } + } + tk2 := testkit.NewTestKit(t, store) + rows := tk1.MustQuery("show session_states").Rows() + state := rows[0][0].(string) + msg := fmt.Sprintf("var name: '%s', expected value: '%s'", tt.varName, tt.expectedValue) + require.Equal(t, tt.inSessionStates, strings.Contains(state, tt.varName), msg) + state = strconv.Quote(state) + setSQL := fmt.Sprintf("set session_states %s", state) + tk2.MustExec(setSQL) + if len(tt.expectedValue) > 0 { + checkStmt := tt.checkStmt + if len(checkStmt) == 0 { + checkStmt = fmt.Sprintf("select @@%s", tt.varName) + } + tk2.MustQuery(checkStmt).Check(testkit.Rows(tt.expectedValue)) + } + } + + { + // The session value should not change even if the global value changes. 
+ tk1 := testkit.NewTestKit(t, store) + tk1.MustQuery("select @@autocommit").Check(testkit.Rows("1")) + tk2 := testkit.NewTestKit(t, store) + tk2.MustExec("set global autocommit=0") + tk3 := testkit.NewTestKit(t, store) + showSessionStatesAndSet(t, tk1, tk3) + tk3.MustQuery("select @@autocommit").Check(testkit.Rows("1")) + } +} + +func TestSessionCtx(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("create table test.t1(id int)") + + tests := []struct { + setFunc func(tk *testkit.TestKit) any + checkFunc func(tk *testkit.TestKit, param any) + }{ + { + // check PreparedStmtID + checkFunc: func(tk *testkit.TestKit, param any) { + require.Equal(t, uint32(1), tk.Session().GetSessionVars().GetNextPreparedStmtID()) + }, + }, + { + // check PreparedStmtID + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("prepare stmt from 'select ?'") + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + require.Equal(t, uint32(2), tk.Session().GetSessionVars().GetNextPreparedStmtID()) + }, + }, + { + // check Status + checkFunc: func(tk *testkit.TestKit, param any) { + require.Equal(t, mysql.ServerStatusAutocommit, tk.Session().GetSessionVars().Status&mysql.ServerStatusAutocommit) + }, + }, + { + // check Status + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("set autocommit=0") + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + require.Equal(t, uint16(0), tk.Session().GetSessionVars().Status&mysql.ServerStatusAutocommit) + }, + }, + { + // check CurrentDB + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select database()").Check(testkit.Rows("")) + }, + }, + { + // check CurrentDB + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("use test") + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select database()").Check(testkit.Rows("test")) + }, + }, + { + // check CurrentDB + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("create database test1") + tk.MustExec("use test1") + tk.MustExec("drop database test1") + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select database()").Check(testkit.Rows("")) + }, + }, + { + // check LastTxnInfo + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select @@tidb_last_txn_info").Check(testkit.Rows("")) + }, + }, + { + // check LastTxnInfo + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("begin") + tk.MustExec("insert test.t1 value(1)") + tk.MustExec("commit") + rows := tk.MustQuery("select @@tidb_last_txn_info").Rows() + require.NotEqual(t, "", rows[0][0].(string)) + return rows + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select @@tidb_last_txn_info").Check(param.([][]interface{})) + }, + }, + { + // check LastQueryInfo + setFunc: func(tk *testkit.TestKit) any { + rows := tk.MustQuery("select @@tidb_last_query_info").Rows() + require.NotEqual(t, "", rows[0][0].(string)) + return rows + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select @@tidb_last_query_info").Check(param.([][]interface{})) + }, + }, + { + // check LastQueryInfo + setFunc: func(tk *testkit.TestKit) any { + tk.MustQuery("select * from test.t1") + startTS := tk.Session().GetSessionVars().LastQueryInfo.StartTS + require.NotEqual(t, uint64(0), startTS) + return startTS + }, + checkFunc: func(tk *testkit.TestKit, param any) { + startTS := tk.Session().GetSessionVars().LastQueryInfo.StartTS + require.Equal(t, 
param.(uint64), startTS) + }, + }, + { + // check LastDDLInfo + setFunc: func(tk *testkit.TestKit) any { + rows := tk.MustQuery("select @@tidb_last_ddl_info").Rows() + require.NotEqual(t, "", rows[0][0].(string)) + return rows + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select @@tidb_last_ddl_info").Check(param.([][]interface{})) + }, + }, + { + // check LastDDLInfo + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("truncate table test.t1") + rows := tk.MustQuery("select @@tidb_last_ddl_info").Rows() + require.NotEqual(t, "", rows[0][0].(string)) + return rows + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select @@tidb_last_ddl_info").Check(param.([][]interface{})) + }, + }, + { + // check LastFoundRows + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("insert test.t1 value(1), (2), (3), (4), (5)") + // SQL_CALC_FOUND_ROWS is not supported now, so we just test normal select. + rows := tk.MustQuery("select * from test.t1 limit 3").Rows() + require.Equal(t, 3, len(rows)) + return "3" + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select found_rows()").Check(testkit.Rows(param.(string))) + }, + }, + { + // check SequenceState + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("create sequence test.s") + tk.MustQuery("select nextval(test.s)").Check(testkit.Rows("1")) + tk.MustQuery("select lastval(test.s)").Check(testkit.Rows("1")) + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select lastval(test.s)").Check(testkit.Rows("1")) + tk.MustQuery("select nextval(test.s)").Check(testkit.Rows("2")) + }, + }, + { + // check FoundInPlanCache + setFunc: func(tk *testkit.TestKit) any { + require.False(t, tk.Session().GetSessionVars().FoundInPlanCache) + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) + }, + }, + { + // check FoundInPlanCache + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("prepare stmt from 'select * from test.t1'") + tk.MustQuery("execute stmt") + tk.MustQuery("execute stmt") + require.True(t, tk.Session().GetSessionVars().FoundInPlanCache) + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + }, + }, + { + // check FoundInBinding + setFunc: func(tk *testkit.TestKit) any { + require.False(t, tk.Session().GetSessionVars().FoundInBinding) + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select @@last_plan_from_binding").Check(testkit.Rows("0")) + }, + }, + { + // check FoundInBinding + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("create session binding for select * from test.t1 using select * from test.t1") + tk.MustQuery("select * from test.t1") + require.True(t, tk.Session().GetSessionVars().FoundInBinding) + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select @@last_plan_from_binding").Check(testkit.Rows("1")) + }, + }, + } + + for _, tt := range tests { + tk1 := testkit.NewTestKit(t, store) + var param any + if tt.setFunc != nil { + param = tt.setFunc(tk1) + } + tk2 := testkit.NewTestKit(t, store) + showSessionStatesAndSet(t, tk1, tk2) + tt.checkFunc(tk2, param) + } +} + +func TestStatementCtx(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("create table test.t1(id int auto_increment primary key, str char(1))") + + 
tests := []struct { + setFunc func(tk *testkit.TestKit) any + checkFunc func(tk *testkit.TestKit, param any) + }{ + { + // check LastAffectedRows + setFunc: func(tk *testkit.TestKit) any { + tk.MustQuery("show warnings") + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select row_count()").Check(testkit.Rows("0")) + tk.MustQuery("select row_count()").Check(testkit.Rows("-1")) + }, + }, + { + // check LastAffectedRows + setFunc: func(tk *testkit.TestKit) any { + tk.MustQuery("select 1") + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select row_count()").Check(testkit.Rows("-1")) + }, + }, + { + // check LastAffectedRows + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("insert into test.t1(str) value('a'), ('b'), ('c')") + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select row_count()").Check(testkit.Rows("3")) + tk.MustQuery("select row_count()").Check(testkit.Rows("-1")) + }, + }, + { + // check LastInsertID + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select @@last_insert_id").Check(testkit.Rows("0")) + }, + }, + { + // check LastInsertID + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("insert into test.t1(str) value('d')") + rows := tk.MustQuery("select @@last_insert_id").Rows() + require.NotEqual(t, "0", rows[0][0].(string)) + return rows + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select @@last_insert_id").Check(param.([][]any)) + }, + }, + { + // check Warning + setFunc: func(tk *testkit.TestKit) any { + tk.MustQuery("select 1") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("show errors").Check(testkit.Rows()) + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("show errors").Check(testkit.Rows()) + tk.MustQuery("select @@warning_count, @@error_count").Check(testkit.Rows("0 0")) + }, + }, + { + // check Warning + setFunc: func(tk *testkit.TestKit) any { + tk.MustGetErrCode("insert into test.t1(str) value('ef')", errno.ErrDataTooLong) + rows := tk.MustQuery("show warnings").Rows() + require.Equal(t, 1, len(rows)) + tk.MustQuery("show errors").Check(rows) + return rows + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("show warnings").Check(param.([][]any)) + tk.MustQuery("show errors").Check(param.([][]any)) + tk.MustQuery("select @@warning_count, @@error_count").Check(testkit.Rows("1 1")) + }, + }, + { + // check Warning + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("set sql_mode=''") + tk.MustExec("insert into test.t1(str) value('ef'), ('ef')") + rows := tk.MustQuery("show warnings").Rows() + require.Equal(t, 2, len(rows)) + tk.MustQuery("show errors").Check(testkit.Rows()) + return rows + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("show warnings").Check(param.([][]any)) + tk.MustQuery("show errors").Check(testkit.Rows()) + tk.MustQuery("select @@warning_count, @@error_count").Check(testkit.Rows("2 0")) + }, + }, + } + + for _, tt := range tests { + tk1 := testkit.NewTestKit(t, store) + var param any + if tt.setFunc != nil { + param = tt.setFunc(tk1) + } + tk2 := testkit.NewTestKit(t, store) + showSessionStatesAndSet(t, tk1, tk2) + tt.checkFunc(tk2, param) + } +} + +func TestPreparedStatements(t *testing.T) { + store := testkit.CreateMockStore(t) + sv := server.CreateMockServer(t, store) + defer sv.Close() + + tests := []struct { + 
setFunc func(tk *testkit.TestKit, conn server.MockConn) any + checkFunc func(tk *testkit.TestKit, conn server.MockConn, param any) + restoreErr int + cleanFunc func(tk *testkit.TestKit) + }{ + { + // no such statement + checkFunc: func(tk *testkit.TestKit, conn server.MockConn, param any) { + tk.MustGetErrCode("execute stmt", errno.ErrPreparedStmtNotFound) + }, + }, + { + // deallocate it after prepare + setFunc: func(tk *testkit.TestKit, conn server.MockConn) any { + tk.MustExec("prepare stmt from 'select 1'") + tk.MustExec("deallocate prepare stmt") + return nil + }, + checkFunc: func(tk *testkit.TestKit, conn server.MockConn, param any) { + tk.MustGetErrCode("execute stmt", errno.ErrPreparedStmtNotFound) + }, + }, + { + // statement with no parameters + setFunc: func(tk *testkit.TestKit, conn server.MockConn) any { + tk.MustExec("create table test.t1(id int)") + tk.MustExec("insert into test.t1 value(1), (2), (3)") + tk.MustExec("prepare stmt from 'select * from test.t1 order by id'") + return nil + }, + checkFunc: func(tk *testkit.TestKit, conn server.MockConn, param any) { + tk.MustQuery("execute stmt").Check(testkit.Rows("1", "2", "3")) + }, + cleanFunc: func(tk *testkit.TestKit) { + tk.MustExec("drop table test.t1") + }, + }, + { + // statement with user-defined parameters + setFunc: func(tk *testkit.TestKit, conn server.MockConn) any { + tk.MustExec("create table test.t1(id int)") + tk.MustExec("insert into test.t1 value(1), (2), (3)") + tk.MustExec("prepare stmt from 'select * from test.t1 where id>? order by id limit ?'") + return nil + }, + checkFunc: func(tk *testkit.TestKit, conn server.MockConn, param any) { + tk.MustExec("set @a=1, @b=1") + tk.MustQuery("execute stmt using @a, @b").Check(testkit.Rows("2")) + }, + cleanFunc: func(tk *testkit.TestKit) { + tk.MustExec("drop table test.t1") + }, + }, + { + // execute the statement multiple times + setFunc: func(tk *testkit.TestKit, conn server.MockConn) any { + tk.MustExec("create table test.t1(id int)") + tk.MustExec("prepare stmt1 from 'insert into test.t1 value(?), (?), (?)'") + tk.MustExec("prepare stmt2 from 'select * from test.t1 order by id'") + return nil + }, + checkFunc: func(tk *testkit.TestKit, conn server.MockConn, param any) { + tk.MustQuery("execute stmt2").Check(testkit.Rows()) + tk.MustExec("set @a=1, @b=2, @c=3") + tk.MustExec("execute stmt1 using @a, @b, @c") + tk.MustQuery("execute stmt2").Check(testkit.Rows("1", "2", "3")) + }, + cleanFunc: func(tk *testkit.TestKit) { + tk.MustExec("drop table test.t1") + }, + }, + { + // update session variables after prepare + setFunc: func(tk *testkit.TestKit, conn server.MockConn) any { + tk.MustExec("set names utf8mb4 collate utf8mb4_general_ci") + tk.MustExec("prepare stmt from 'select @@character_set_client, @@collation_connection'") + tk.MustQuery("execute stmt").Check(testkit.Rows("utf8mb4 utf8mb4_general_ci")) + tk.MustExec("set names gbk collate gbk_chinese_ci") + return nil + }, + checkFunc: func(tk *testkit.TestKit, conn server.MockConn, param any) { + tk.MustQuery("execute stmt").Check(testkit.Rows("gbk gbk_chinese_ci")) + }, + }, + { + // session-scoped ANSI_QUOTES + setFunc: func(tk *testkit.TestKit, conn server.MockConn) any { + tk.MustExec("set sql_mode='ANSI_QUOTES'") + tk.MustExec("prepare stmt from 'select \\'a\\''") + tk.MustQuery("execute stmt").Check(testkit.Rows("a")) + return nil + }, + checkFunc: func(tk *testkit.TestKit, conn server.MockConn, param any) { + tk.MustQuery("execute stmt").Check(testkit.Rows("a")) + }, + }, + { + // global-scoped 
ANSI_QUOTES + setFunc: func(tk *testkit.TestKit, conn server.MockConn) any { + tk.MustExec("set global sql_mode='ANSI_QUOTES'") + tk.MustExec("prepare stmt from \"select \\\"a\\\"\"") + tk.MustQuery("execute stmt").Check(testkit.Rows("a")) + return nil + }, + restoreErr: errno.ErrBadField, + }, + { + // statement name + setFunc: func(tk *testkit.TestKit, conn server.MockConn) any { + tk.MustExec("prepare `stmt 1` from 'select 1'") + tk.MustQuery("execute `stmt 1`").Check(testkit.Rows("1")) + return nil + }, + checkFunc: func(tk *testkit.TestKit, conn server.MockConn, param any) { + tk.MustQuery("execute `stmt 1`").Check(testkit.Rows("1")) + }, + }, + { + // multiple prepared statements + setFunc: func(tk *testkit.TestKit, conn server.MockConn) any { + tk.MustExec("create table test.t1(id int)") + tk.MustExec("insert into test.t1 value(1), (2), (3)") + tk.MustExec("prepare stmt1 from 'select * from test.t1 order by id'") + tk.MustExec("prepare stmt2 from 'select * from test.t1 where id=?'") + tk.MustExec("set @a=1") + return nil + }, + checkFunc: func(tk *testkit.TestKit, conn server.MockConn, param any) { + tk.MustQuery("execute stmt1").Check(testkit.Rows("1", "2", "3")) + tk.MustQuery("execute stmt2 using @a").Check(testkit.Rows("1")) + }, + cleanFunc: func(tk *testkit.TestKit) { + tk.MustExec("drop table test.t1") + }, + }, + { + // change current db after prepare + setFunc: func(tk *testkit.TestKit, conn server.MockConn) any { + tk.MustExec("use test") + tk.MustExec("create table t1(id int)") + tk.MustExec("insert into t1 value(1), (2), (3)") + tk.MustExec("prepare stmt from 'select * from t1 order by id'") + tk.MustExec("use mysql") + return nil + }, + checkFunc: func(tk *testkit.TestKit, conn server.MockConn, param any) { + tk.MustQuery("select database()").Check(testkit.Rows("mysql")) + tk.MustQuery("execute stmt").Check(testkit.Rows("1", "2", "3")) + }, + cleanFunc: func(tk *testkit.TestKit) { + tk.MustExec("drop table test.t1") + }, + }, + { + // update user variable after prepare + setFunc: func(tk *testkit.TestKit, conn server.MockConn) any { + tk.MustExec("create table test.t1(id int)") + tk.MustExec("insert into test.t1 value(1), (2), (3)") + tk.MustExec("set @a=1") + tk.MustExec("prepare stmt from 'select * from test.t1 where id=?'") + tk.MustQuery("execute stmt using @a").Check(testkit.Rows("1")) + tk.MustExec("set @a=2") + return nil + }, + checkFunc: func(tk *testkit.TestKit, conn server.MockConn, param any) { + tk.MustQuery("execute stmt using @a").Check(testkit.Rows("2")) + }, + cleanFunc: func(tk *testkit.TestKit) { + tk.MustExec("drop table test.t1") + }, + }, + { + // alter table after prepare + setFunc: func(tk *testkit.TestKit, conn server.MockConn) any { + tk.MustExec("create table test.t1(id int)") + tk.MustExec("insert into test.t1 value(1)") + tk.MustExec("prepare stmt from 'select * from test.t1'") + tk.MustExec("alter table test.t1 add column c char(1) default 'a'") + return nil + }, + checkFunc: func(tk *testkit.TestKit, conn server.MockConn, param any) { + tk.MustQuery("execute stmt").Check(testkit.Rows("1 a")) + }, + cleanFunc: func(tk *testkit.TestKit) { + tk.MustExec("drop table test.t1") + }, + }, + { + // drop and create table after prepare + setFunc: func(tk *testkit.TestKit, conn server.MockConn) any { + tk.MustExec("create table test.t1(id int)") + tk.MustExec("prepare stmt from 'select * from test.t1'") + tk.MustExec("drop table test.t1") + tk.MustExec("create table test.t1(id int, c char(1))") + tk.MustExec("insert into test.t1 value(1, 'a')") + 
return nil + }, + checkFunc: func(tk *testkit.TestKit, conn server.MockConn, param any) { + tk.MustQuery("execute stmt").Check(testkit.Rows("1 a")) + }, + cleanFunc: func(tk *testkit.TestKit) { + tk.MustExec("drop table test.t1") + }, + }, + { + // drop table after prepare + setFunc: func(tk *testkit.TestKit, conn server.MockConn) any { + tk.MustExec("create table test.t1(id int)") + tk.MustExec("prepare stmt from 'select * from test.t1'") + tk.MustExec("drop table test.t1") + return nil + }, + restoreErr: errno.ErrNoSuchTable, + }, + { + // drop db after prepare + setFunc: func(tk *testkit.TestKit, conn server.MockConn) any { + tk.MustExec("create database test1") + tk.MustExec("use test1") + tk.MustExec("create table t1(id int)") + tk.MustExec("prepare stmt from 'select * from t1'") + tk.MustExec("drop database test1") + return nil + }, + restoreErr: errno.ErrNoSuchTable, + }, + { + // update sql_mode after prepare + setFunc: func(tk *testkit.TestKit, conn server.MockConn) any { + tk.MustExec("set sql_mode=''") + tk.MustExec("create table test.t1(id int, name char(10))") + tk.MustExec("insert into test.t1 value(1, 'a')") + tk.MustExec("prepare stmt from 'select id, name from test.t1 group by id'") + tk.MustExec("set sql_mode='ONLY_FULL_GROUP_BY'") + return nil + }, + checkFunc: func(tk *testkit.TestKit, conn server.MockConn, param any) { + // The prepare statement is decoded after decoding session variables, + // so `SET SESSION_STATES` won't report errors. + tk.MustGetErrCode("execute stmt", errno.ErrFieldNotInGroupBy) + }, + cleanFunc: func(tk *testkit.TestKit) { + tk.MustExec("drop table test.t1") + }, + }, + { + // update global sql_mode after prepare + setFunc: func(tk *testkit.TestKit, conn server.MockConn) any { + tk.MustExec("set sql_mode=''") + tk.MustExec("create table test.t1(id int, name char(10))") + tk.MustExec("prepare stmt from 'select id, name from test.t1 group by id'") + tk.MustExec("set global sql_mode='ONLY_FULL_GROUP_BY'") + return nil + }, + restoreErr: errno.ErrFieldNotInGroupBy, + cleanFunc: func(tk *testkit.TestKit) { + tk.MustExec("drop table test.t1") + tk.MustExec("set global sql_mode=default") + }, + }, + { + // warnings won't be affected + setFunc: func(tk *testkit.TestKit, conn server.MockConn) any { + // Decoding this prepared statement should report a warning. + tk.MustExec("prepare stmt from 'select 0/0'") + // Override the warning. 
+ tk.MustQuery("select 1") + return nil + }, + checkFunc: func(tk *testkit.TestKit, conn server.MockConn, param any) { + tk.MustQuery("show warnings").Check(testkit.Rows()) + }, + }, + { + // test binary-protocol prepared statement + setFunc: func(tk *testkit.TestKit, conn server.MockConn) any { + stmtID, _, _, err := tk.Session().PrepareStmt("select ?") + require.NoError(t, err) + return stmtID + }, + checkFunc: func(tk *testkit.TestKit, conn server.MockConn, param any) { + rs, err := tk.Session().ExecutePreparedStmt(context.Background(), param.(uint32), expression.Args2Expressions4Test(1)) + require.NoError(t, err) + tk.ResultSetToResult(rs, "").Check(testkit.Rows("1")) + }, + }, + { + // no such prepared statement + checkFunc: func(tk *testkit.TestKit, conn server.MockConn, param any) { + _, err := tk.Session().ExecutePreparedStmt(context.Background(), 1, nil) + errEqualsCode(t, err, errno.ErrPreparedStmtNotFound) + }, + }, + { + // both text and binary protocols + setFunc: func(tk *testkit.TestKit, conn server.MockConn) any { + tk.MustExec("prepare stmt from 'select 10'") + stmtID, _, _, err := tk.Session().PrepareStmt("select ?") + require.NoError(t, err) + return stmtID + }, + checkFunc: func(tk *testkit.TestKit, conn server.MockConn, param any) { + tk.MustQuery("execute stmt").Check(testkit.Rows("10")) + rs, err := tk.Session().ExecutePreparedStmt(context.Background(), param.(uint32), expression.Args2Expressions4Test(1)) + require.NoError(t, err) + tk.ResultSetToResult(rs, "").Check(testkit.Rows("1")) + }, + }, + { + // drop binary protocol statements + setFunc: func(tk *testkit.TestKit, conn server.MockConn) any { + stmtID, _, _, err := tk.Session().PrepareStmt("select ?") + require.NoError(t, err) + return stmtID + }, + checkFunc: func(tk *testkit.TestKit, conn server.MockConn, param any) { + err := tk.Session().DropPreparedStmt(param.(uint32)) + require.NoError(t, err) + }, + }, + { + // execute the statement multiple times + setFunc: func(tk *testkit.TestKit, conn server.MockConn) any { + tk.MustExec("create table test.t1(id int)") + stmtID1, _, _, err := tk.Session().PrepareStmt("insert into test.t1 value(?), (?), (?)") + require.NoError(t, err) + stmtID2, _, _, err := tk.Session().PrepareStmt("select * from test.t1 order by id") + require.NoError(t, err) + return []uint32{stmtID1, stmtID2} + }, + checkFunc: func(tk *testkit.TestKit, conn server.MockConn, param any) { + stmtIDs := param.([]uint32) + rs, err := tk.Session().ExecutePreparedStmt(context.Background(), stmtIDs[1], nil) + require.NoError(t, err) + tk.ResultSetToResult(rs, "").Check(testkit.Rows()) + _, err = tk.Session().ExecutePreparedStmt(context.Background(), stmtIDs[0], expression.Args2Expressions4Test(1, 2, 3)) + require.NoError(t, err) + rs, err = tk.Session().ExecutePreparedStmt(context.Background(), stmtIDs[1], nil) + require.NoError(t, err) + tk.ResultSetToResult(rs, "").Check(testkit.Rows("1", "2", "3")) + }, + cleanFunc: func(tk *testkit.TestKit) { + tk.MustExec("drop table test.t1") + }, + }, + { + // the latter stmt ID should be bigger + setFunc: func(tk *testkit.TestKit, conn server.MockConn) any { + stmtID, _, _, err := tk.Session().PrepareStmt("select ?") + require.NoError(t, err) + return stmtID + }, + checkFunc: func(tk *testkit.TestKit, conn server.MockConn, param any) { + stmtID, _, _, err := tk.Session().PrepareStmt("select ?") + require.NoError(t, err) + require.True(t, stmtID > param.(uint32)) + }, + }, + { + // execute the statement with cursor + setFunc: func(tk *testkit.TestKit, conn 
server.MockConn) any { + cmd := append([]byte{mysql.ComStmtPrepare}, []byte("select ?")...) + require.NoError(t, conn.Dispatch(context.Background(), cmd)) + cmd = getExecuteBytes(1, true, true, paramInfo{value: 1, isNull: false}) + require.NoError(t, conn.Dispatch(context.Background(), cmd)) + cmd = getFetchBytes(1, 10) + require.NoError(t, conn.Dispatch(context.Background(), cmd)) + // This COM_STMT_FETCH returns EOF. + cmd = getFetchBytes(1, 10) + require.NoError(t, conn.Dispatch(context.Background(), cmd)) + return uint32(1) + }, + checkFunc: func(tk *testkit.TestKit, conn server.MockConn, param any) { + cmd := getExecuteBytes(param.(uint32), false, false, paramInfo{value: 1, isNull: false}) + require.NoError(t, conn.Dispatch(context.Background(), cmd)) + }, + }, + // Skip this case. Refer to https://github.com/pingcap/tidb/issues/35784. + //{ + // // update privilege after prepare + // setFunc: func(tk *testkit.TestKit, conn server.MockConn) any { + // rootTk := testkit.NewTestKit(t, store) + // rootTk.MustExec(`CREATE USER 'u1'@'localhost'`) + // rootTk.MustExec("create table test.t1(id int)") + // require.NoError(t, tk.Session().Auth(&auth.UserIdentity{Username: "u1", Hostname: "localhost"}, nil, nil)) + // rootTk.MustExec(`GRANT SELECT ON test.t1 TO 'u1'@'localhost'`) + // tk.MustExec("prepare stmt from 'select * from test.t1'") + // rootTk.MustExec(`REVOKE SELECT ON test.t1 FROM 'u1'@'localhost'`) + // return nil + // }, + // prepareFunc: func(tk *testkit.TestKit, conn server.MockConn) { + // require.NoError(t, tk.Session().Auth(&auth.UserIdentity{Username: "u1", Hostname: "localhost"}, nil, nil)) + // }, + // restoreErr: errno.ErrNoSuchTable, + // cleanFunc: func(tk *testkit.TestKit) { + // rootTk := testkit.NewTestKit(t, store) + // rootTk.MustExec("drop user 'u1'@'localhost'") + // rootTk.MustExec("drop table test.t1") + // }, + //}, + } + + for _, tt := range tests { + conn1 := server.CreateMockConn(t, sv) + tk1 := testkit.NewTestKitWithSession(t, store, conn1.Context().Session) + conn1.Context().Session.GetSessionVars().User = nil + var param any + if tt.setFunc != nil { + param = tt.setFunc(tk1, conn1) + } + conn2 := server.CreateMockConn(t, sv) + tk2 := testkit.NewTestKitWithSession(t, store, conn2.Context().Session) + rows := tk1.MustQuery("show session_states").Rows() + require.Len(t, rows, 1) + state := rows[0][0].(string) + state = strings.ReplaceAll(state, "\\", "\\\\") + state = strings.ReplaceAll(state, "'", "\\'") + setSQL := fmt.Sprintf("set session_states '%s'", state) + if tt.restoreErr != 0 { + tk2.MustGetErrCode(setSQL, tt.restoreErr) + } else { + tk2.MustExec(setSQL) + tt.checkFunc(tk2, conn2, param) + } + if tt.cleanFunc != nil { + tt.cleanFunc(tk1) + } + conn1.Close() + conn2.Close() + } +} + +func TestSQLBinding(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("create table test.t1(id int primary key, name varchar(10), key(name))") + + tests := []struct { + setFunc func(tk *testkit.TestKit) any + checkFunc func(tk *testkit.TestKit, param any) + restoreErr int + cleanFunc func(tk *testkit.TestKit) + }{ + { + // no bindings + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("show session bindings").Check(testkit.Rows()) + }, + }, + { + // use binding and drop it + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("create session binding for select * from test.t1 using select * from test.t1 use index(name)") + rows := tk.MustQuery("show session bindings").Rows() + require.Equal(t, 1, 
len(rows)) + return rows + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("show session bindings").Check(param.([][]any)) + require.True(t, tk.HasPlan("select * from test.t1", "IndexFullScan")) + tk.MustExec("drop session binding for select * from test.t1") + tk.MustQuery("show session bindings").Check(testkit.Rows()) + }, + }, + { + // use hint + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("create session binding for select * from test.t1 using select /*+ use_index(test.t1, name) */ * from test.t1") + rows := tk.MustQuery("show session bindings").Rows() + require.Equal(t, 1, len(rows)) + return rows + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("show session bindings").Check(param.([][]any)) + require.True(t, tk.HasPlan("select * from test.t1", "IndexFullScan")) + tk.MustExec("drop session binding for select * from test.t1") + tk.MustQuery("show session bindings").Check(testkit.Rows()) + }, + }, + { + // drop binding + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("create session binding for select * from test.t1 using select * from test.t1 use index(name)") + tk.MustExec("drop session binding for select * from test.t1") + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("show session bindings").Check(testkit.Rows()) + }, + }, + { + // default db + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("use test") + tk.MustExec("create session binding for select * from t1 using select * from t1 use index(name)") + tk.MustExec("use mysql") + rows := tk.MustQuery("show session bindings").Rows() + require.Equal(t, 1, len(rows)) + return rows + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("show session bindings").Check(param.([][]any)) + require.True(t, tk.HasPlan("select * from test.t1", "IndexFullScan")) + }, + }, + { + // drop table + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("create session binding for select * from test.t1 using select * from test.t1 use index(name)") + tk.MustExec("drop table test.t1") + return nil + }, + restoreErr: errno.ErrNoSuchTable, + cleanFunc: func(tk *testkit.TestKit) { + tk.MustExec("create table test.t1(id int primary key, name varchar(10), key(name))") + }, + }, + { + // drop db + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("create database test1") + tk.MustExec("use test1") + tk.MustExec("create table t1(id int primary key, name varchar(10), key(name))") + tk.MustExec("create session binding for select * from t1 using select /*+ use_index(t1, name) */ * from t1") + tk.MustExec("drop database test1") + return nil + }, + restoreErr: errno.ErrNoSuchTable, + }, + { + // alter the table + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("create session binding for select * from test.t1 using select * from test.t1 use index(name)") + tk.MustExec("alter table test.t1 drop index name") + return nil + }, + restoreErr: errno.ErrKeyDoesNotExist, + cleanFunc: func(tk *testkit.TestKit) { + tk.MustExec("alter table test.t1 add index name(name)") + }, + }, + { + // both global and session bindings + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("create global binding for select * from test.t1 using select * from test.t1 use index(primary)") + tk.MustExec("create session binding for select * from test.t1 using select * from test.t1 use index(name)") + sessionRows := tk.MustQuery("show bindings").Rows() + require.Equal(t, 1, len(sessionRows)) + globalRows := tk.MustQuery("show global bindings").Rows() + 
require.Equal(t, 1, len(globalRows))
+			return [][][]any{sessionRows, globalRows}
+		},
+		checkFunc: func(tk *testkit.TestKit, param any) {
+			rows := param.([][][]any)
+			tk.MustQuery("show bindings").Check(rows[0])
+			tk.MustQuery("show global bindings").Check(rows[1])
+			require.True(t, tk.HasPlan("select * from test.t1", "IndexFullScan"))
+			tk.MustExec("drop session binding for select * from test.t1")
+			require.True(t, tk.HasPlan("select * from test.t1", "TableFullScan"))
+		},
+		cleanFunc: func(tk *testkit.TestKit) {
+			tk.MustExec("drop global binding for select * from test.t1")
+		},
+	},
+	{
+		// multiple bindings
+		setFunc: func(tk *testkit.TestKit) any {
+			tk.MustExec("create session binding for select * from test.t1 using select * from test.t1 use index(name)")
+			tk.MustExec("create session binding for select count(*) from test.t1 using select count(*) from test.t1 use index(primary)")
+			tk.MustExec("create session binding for select name from test.t1 using select name from test.t1 use index(primary)")
+			rows := tk.MustQuery("show bindings").Rows()
+			require.Equal(t, 3, len(rows))
+			return rows
+		},
+		checkFunc: func(tk *testkit.TestKit, param any) {
+			tk.MustQuery("show bindings").Check(param.([][]any))
+			require.True(t, tk.HasPlan("select * from test.t1", "IndexFullScan"))
+		},
+	},
+	}
+
+	for _, tt := range tests {
+		tk1 := testkit.NewTestKit(t, store)
+		var param any
+		if tt.setFunc != nil {
+			param = tt.setFunc(tk1)
+		}
+		rows := tk1.MustQuery("show session_states").Rows()
+		require.Len(t, rows, 1)
+		state := rows[0][0].(string)
+		state = strconv.Quote(state)
+		setSQL := fmt.Sprintf("set session_states %s", state)
+		tk2 := testkit.NewTestKit(t, store)
+		if tt.restoreErr != 0 {
+			tk2.MustGetErrCode(setSQL, tt.restoreErr)
+		} else {
+			tk2.MustExec(setSQL)
+			tt.checkFunc(tk2, param)
+		}
+		if tt.cleanFunc != nil {
+			tt.cleanFunc(tk1)
+		}
+	}
+}
+
+func TestShowStateFail(t *testing.T) {
+	store := testkit.CreateMockStore(t)
+	sv := server.CreateMockServer(t, store)
+	defer sv.Close()
+
+	tests := []struct {
+		setFunc func(tk *testkit.TestKit, conn server.MockConn)
+		showErr int
+		cleanFunc func(tk *testkit.TestKit)
+	}{
+		{
+			// in an active transaction
+			setFunc: func(tk *testkit.TestKit, conn server.MockConn) {
+				tk.MustExec("begin")
+			},
+			showErr: errno.ErrCannotMigrateSession,
+		},
+		{
+			// out of transaction
+			setFunc: func(tk *testkit.TestKit, conn server.MockConn) {
+				tk.MustExec("begin")
+				tk.MustExec("commit")
+			},
+		},
+		{
+			// created a global temporary table
+			setFunc: func(tk *testkit.TestKit, conn server.MockConn) {
+				tk.MustExec("create global temporary table test.t1(id int) on commit delete rows")
+				tk.MustExec("insert into test.t1 value(1)")
+			},
+			cleanFunc: func(tk *testkit.TestKit) {
+				tk.MustExec("drop table test.t1")
+			},
+		},
+		{
+			// created a local temporary table
+			setFunc: func(tk *testkit.TestKit, conn server.MockConn) {
+				tk.MustExec("create temporary table test.t1(id int)")
+			},
+			showErr: errno.ErrCannotMigrateSession,
+		},
+		{
+			// drop the local temporary table
+			setFunc: func(tk *testkit.TestKit, conn server.MockConn) {
+				tk.MustExec("create temporary table test.t1(id int)")
+				tk.MustExec("drop table test.t1")
+			},
+		},
+		{
+			// hold an advisory lock
+			setFunc: func(tk *testkit.TestKit, conn server.MockConn) {
+				tk.MustQuery("SELECT get_lock('testlock1', 0)").Check(testkit.Rows("1"))
+			},
+			showErr: errno.ErrCannotMigrateSession,
+			cleanFunc: func(tk *testkit.TestKit) {
+				tk.MustQuery("SELECT release_lock('testlock1')").Check(testkit.Rows("1"))
+			},
}, + { + // release the advisory lock + setFunc: func(tk *testkit.TestKit, conn server.MockConn) { + tk.MustQuery("SELECT get_lock('testlock1', 0)").Check(testkit.Rows("1")) + tk.MustQuery("SELECT release_lock('testlock1')").Check(testkit.Rows("1")) + }, + }, + { + // hold table locks + setFunc: func(tk *testkit.TestKit, conn server.MockConn) { + tk.MustExec("create table test.t1(id int)") + tk.MustExec("lock tables test.t1 write") + tk.MustQuery("show warnings").Check(testkit.Rows()) + }, + showErr: errno.ErrCannotMigrateSession, + cleanFunc: func(tk *testkit.TestKit) { + tk.MustExec("drop table test.t1") + }, + }, + { + // unlock the tables + setFunc: func(tk *testkit.TestKit, conn server.MockConn) { + tk.MustExec("create table test.t1(id int)") + tk.MustExec("lock tables test.t1 write") + tk.MustExec("unlock tables") + }, + cleanFunc: func(tk *testkit.TestKit) { + tk.MustExec("drop table test.t1") + }, + }, + { + // after COM_STMT_SEND_LONG_DATA + setFunc: func(tk *testkit.TestKit, conn server.MockConn) { + cmd := append([]byte{mysql.ComStmtPrepare}, []byte("select ?")...) + require.NoError(t, conn.Dispatch(context.Background(), cmd)) + cmd = getLongDataBytes(1, 0, []byte("abc")) + require.NoError(t, conn.Dispatch(context.Background(), cmd)) + }, + showErr: errno.ErrCannotMigrateSession, + }, + { + // after COM_STMT_SEND_LONG_DATA and COM_STMT_EXECUTE + setFunc: func(tk *testkit.TestKit, conn server.MockConn) { + cmd := append([]byte{mysql.ComStmtPrepare}, []byte("select ?")...) + require.NoError(t, conn.Dispatch(context.Background(), cmd)) + cmd = getLongDataBytes(1, 0, []byte("abc")) + require.NoError(t, conn.Dispatch(context.Background(), cmd)) + cmd = getExecuteBytes(1, false, true, paramInfo{value: 1, isNull: false}) + require.NoError(t, conn.Dispatch(context.Background(), cmd)) + }, + }, + { + // query with cursor, and data is not fetched + setFunc: func(tk *testkit.TestKit, conn server.MockConn) { + tk.MustExec("create table test.t1(id int)") + tk.MustExec("insert test.t1 value(1), (2), (3)") + cmd := append([]byte{mysql.ComStmtPrepare}, []byte("select * from test.t1")...) + require.NoError(t, conn.Dispatch(context.Background(), cmd)) + cmd = getExecuteBytes(1, true, false) + require.NoError(t, conn.Dispatch(context.Background(), cmd)) + }, + showErr: errno.ErrCannotMigrateSession, + cleanFunc: func(tk *testkit.TestKit) { + tk.MustExec("drop table test.t1") + }, + }, + { + // fetched all the data but the EOF packet is not sent + setFunc: func(tk *testkit.TestKit, conn server.MockConn) { + tk.MustExec("create table test.t1(id int)") + tk.MustExec("insert test.t1 value(1), (2), (3)") + cmd := append([]byte{mysql.ComStmtPrepare}, []byte("select * from test.t1")...) + require.NoError(t, conn.Dispatch(context.Background(), cmd)) + cmd = getExecuteBytes(1, true, false) + require.NoError(t, conn.Dispatch(context.Background(), cmd)) + cmd = getFetchBytes(1, 10) + require.NoError(t, conn.Dispatch(context.Background(), cmd)) + }, + showErr: errno.ErrCannotMigrateSession, + cleanFunc: func(tk *testkit.TestKit) { + tk.MustExec("drop table test.t1") + }, + }, + { + // EOF is sent + setFunc: func(tk *testkit.TestKit, conn server.MockConn) { + tk.MustExec("create table test.t1(id int)") + tk.MustExec("insert test.t1 value(1), (2), (3)") + cmd := append([]byte{mysql.ComStmtPrepare}, []byte("select * from test.t1")...) 
+ require.NoError(t, conn.Dispatch(context.Background(), cmd)) + cmd = getExecuteBytes(1, true, false) + require.NoError(t, conn.Dispatch(context.Background(), cmd)) + cmd = getFetchBytes(1, 10) + require.NoError(t, conn.Dispatch(context.Background(), cmd)) + // This COM_STMT_FETCH returns EOF. + cmd = getFetchBytes(1, 10) + require.NoError(t, conn.Dispatch(context.Background(), cmd)) + }, + cleanFunc: func(tk *testkit.TestKit) { + tk.MustExec("drop table test.t1") + }, + }, + { + // statement is reset + setFunc: func(tk *testkit.TestKit, conn server.MockConn) { + tk.MustExec("create table test.t1(id int)") + tk.MustExec("insert test.t1 value(1), (2), (3)") + cmd := append([]byte{mysql.ComStmtPrepare}, []byte("select * from test.t1")...) + require.NoError(t, conn.Dispatch(context.Background(), cmd)) + cmd = getExecuteBytes(1, true, false) + require.NoError(t, conn.Dispatch(context.Background(), cmd)) + cmd = getResetBytes(1) + require.NoError(t, conn.Dispatch(context.Background(), cmd)) + }, + cleanFunc: func(tk *testkit.TestKit) { + tk.MustExec("drop table test.t1") + }, + }, + } + + defer config.RestoreFunc()() + config.UpdateGlobal(func(conf *config.Config) { + conf.EnableTableLock = true + }) + for _, tt := range tests { + conn1 := server.CreateMockConn(t, sv) + conn1.Context().Session.GetSessionVars().User = nil + tk1 := testkit.NewTestKitWithSession(t, store, conn1.Context().Session) + tt.setFunc(tk1, conn1) + if tt.showErr == 0 { + tk2 := testkit.NewTestKit(t, store) + showSessionStatesAndSet(t, tk1, tk2) + } else { + err := tk1.QueryToErr("show session_states") + errEqualsCode(t, err, tt.showErr) + } + if tt.cleanFunc != nil { + tt.cleanFunc(tk1) + } + conn1.Close() + } +} + +func showSessionStatesAndSet(t *testing.T, tk1, tk2 *testkit.TestKit) { + rows := tk1.MustQuery("show session_states").Rows() + require.Len(t, rows, 1) + state := rows[0][0].(string) + state = strconv.Quote(state) + setSQL := fmt.Sprintf("set session_states %s", state) + tk2.MustExec(setSQL) +} + +func errEqualsCode(t *testing.T, err error, code int) { + require.NotNil(t, err) + originErr := errors.Cause(err) + tErr, ok := originErr.(*terror.Error) + require.True(t, ok) + sqlErr := terror.ToSQLError(tErr) + require.Equal(t, code, int(sqlErr.Code)) +} + +// create bytes for COM_STMT_SEND_LONG_DATA +func getLongDataBytes(stmtID uint32, paramID uint16, param []byte) []byte { + buf := make([]byte, 7+len(param)) + pos := 0 + buf[pos] = mysql.ComStmtSendLongData + pos++ + binary.LittleEndian.PutUint32(buf[pos:], stmtID) + pos += 4 + binary.LittleEndian.PutUint16(buf[pos:], paramID) + pos += 2 + buf = append(buf[:pos], param...) + return buf +} + +type paramInfo struct { + value uint32 + isNull bool +} + +// create bytes for COM_STMT_EXECUTE. It only supports int type for convenience. 
+func getExecuteBytes(stmtID uint32, useCursor bool, newParam bool, params ...paramInfo) []byte {
+	nullBitmapLen := (len(params) + 7) >> 3
+	buf := make([]byte, 11+nullBitmapLen+len(params)*6)
+	pos := 0
+	buf[pos] = mysql.ComStmtExecute
+	pos++
+	binary.LittleEndian.PutUint32(buf[pos:], stmtID)
+	pos += 4
+	if useCursor {
+		buf[pos] = 1
+	}
+	pos++
+	binary.LittleEndian.PutUint32(buf[pos:], 1)
+	pos += 4
+	for i, param := range params {
+		if param.isNull {
+			buf[pos+(i>>3)] |= 1 << (i % 8)
+		}
+	}
+	pos += nullBitmapLen
+	if newParam {
+		buf[pos] = 1
+		pos++
+		for i := 0; i < len(params); i++ {
+			buf[pos] = mysql.TypeLong
+			pos++
+			buf[pos] = 0
+			pos++
+		}
+	} else {
+		buf[pos] = 0
+		pos++
+	}
+	for _, param := range params {
+		if !param.isNull {
+			binary.LittleEndian.PutUint32(buf[pos:], param.value)
+			pos += 4
+		}
+	}
+	return buf[:pos]
+}
+
+// create bytes for COM_STMT_FETCH.
+func getFetchBytes(stmtID, fetchSize uint32) []byte {
+	buf := make([]byte, 9)
+	pos := 0
+	buf[pos] = mysql.ComStmtFetch
+	pos++
+	binary.LittleEndian.PutUint32(buf[pos:], stmtID)
+	pos += 4
+	binary.LittleEndian.PutUint32(buf[pos:], fetchSize)
+	return buf
+}
+
+// create bytes for COM_STMT_RESET.
+func getResetBytes(stmtID uint32) []byte {
+	buf := make([]byte, 5)
+	pos := 0
+	buf[pos] = mysql.ComStmtReset
+	pos++
+	binary.LittleEndian.PutUint32(buf[pos:], stmtID)
+	return buf
+}
diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go
index ae6d14df23f6b..0eddedd61fb7e 100644
--- a/sessionctx/variable/session.go
+++ b/sessionctx/variable/session.go
@@ -976,9 +976,12 @@ type SessionVars struct {
 	// TemporaryTableData stores committed kv values for temporary table for current session.
 	TemporaryTableData TemporaryTableData

+<<<<<<< HEAD
 	// MPPStoreLastFailTime records the latest fail time that a TiFlash store failed.
 	MPPStoreLastFailTime map[string]time.Time

+=======
+>>>>>>> aeccf77637 (*: optimize mpp probe (#39932))
 	// MPPStoreFailTTL indicates the duration that protect TiDB from sending task to a new recovered TiFlash.
 	MPPStoreFailTTL string
@@ -1148,6 +1151,7 @@ type ConnectionInfo struct {
 // NewSessionVars creates a session vars object.
func NewSessionVars() *SessionVars { vars := &SessionVars{ +<<<<<<< HEAD Users: make(map[string]types.Datum), UserVarTypes: make(map[string]*types.FieldType), systems: make(map[string]string), @@ -1237,6 +1241,111 @@ func NewSessionVars() *SessionVars { EnablePlacementChecks: DefEnablePlacementCheck, Rng: utilMath.NewWithTime(), StatsLoadSyncWait: StatsLoadSyncWait.Load(), +======= + userVars: struct { + lock sync.RWMutex + values map[string]types.Datum + types map[string]*types.FieldType + }{ + values: make(map[string]types.Datum), + types: make(map[string]*types.FieldType), + }, + systems: make(map[string]string), + stmtVars: make(map[string]string), + PreparedStmts: make(map[uint32]interface{}), + PreparedStmtNameToID: make(map[string]uint32), + PreparedParams: make([]types.Datum, 0, 10), + TxnCtx: &TransactionContext{}, + RetryInfo: &RetryInfo{}, + ActiveRoles: make([]*auth.RoleIdentity, 0, 10), + StrictSQLMode: true, + AutoIncrementIncrement: DefAutoIncrementIncrement, + AutoIncrementOffset: DefAutoIncrementOffset, + Status: mysql.ServerStatusAutocommit, + StmtCtx: new(stmtctx.StatementContext), + AllowAggPushDown: false, + AllowCartesianBCJ: DefOptCartesianBCJ, + MPPOuterJoinFixedBuildSide: DefOptMPPOuterJoinFixedBuildSide, + BroadcastJoinThresholdSize: DefBroadcastJoinThresholdSize, + BroadcastJoinThresholdCount: DefBroadcastJoinThresholdSize, + OptimizerSelectivityLevel: DefTiDBOptimizerSelectivityLevel, + EnableOuterJoinReorder: DefTiDBEnableOuterJoinReorder, + RetryLimit: DefTiDBRetryLimit, + DisableTxnAutoRetry: DefTiDBDisableTxnAutoRetry, + DDLReorgPriority: kv.PriorityLow, + allowInSubqToJoinAndAgg: DefOptInSubqToJoinAndAgg, + preferRangeScan: DefOptPreferRangeScan, + EnableCorrelationAdjustment: DefOptEnableCorrelationAdjustment, + LimitPushDownThreshold: DefOptLimitPushDownThreshold, + CorrelationThreshold: DefOptCorrelationThreshold, + CorrelationExpFactor: DefOptCorrelationExpFactor, + cpuFactor: DefOptCPUFactor, + copCPUFactor: DefOptCopCPUFactor, + CopTiFlashConcurrencyFactor: DefOptTiFlashConcurrencyFactor, + networkFactor: DefOptNetworkFactor, + scanFactor: DefOptScanFactor, + descScanFactor: DefOptDescScanFactor, + seekFactor: DefOptSeekFactor, + memoryFactor: DefOptMemoryFactor, + diskFactor: DefOptDiskFactor, + concurrencyFactor: DefOptConcurrencyFactor, + enableForceInlineCTE: DefOptForceInlineCTE, + EnableVectorizedExpression: DefEnableVectorizedExpression, + CommandValue: uint32(mysql.ComSleep), + TiDBOptJoinReorderThreshold: DefTiDBOptJoinReorderThreshold, + SlowQueryFile: config.GetGlobalConfig().Log.SlowQueryFile, + WaitSplitRegionFinish: DefTiDBWaitSplitRegionFinish, + WaitSplitRegionTimeout: DefWaitSplitRegionTimeout, + enableIndexMerge: DefTiDBEnableIndexMerge, + NoopFuncsMode: TiDBOptOnOffWarn(DefTiDBEnableNoopFuncs), + replicaRead: kv.ReplicaReadLeader, + AllowRemoveAutoInc: DefTiDBAllowRemoveAutoInc, + UsePlanBaselines: DefTiDBUsePlanBaselines, + EvolvePlanBaselines: DefTiDBEvolvePlanBaselines, + EnableExtendedStats: false, + IsolationReadEngines: make(map[kv.StoreType]struct{}), + LockWaitTimeout: DefInnodbLockWaitTimeout * 1000, + MetricSchemaStep: DefTiDBMetricSchemaStep, + MetricSchemaRangeDuration: DefTiDBMetricSchemaRangeDuration, + SequenceState: NewSequenceState(), + WindowingUseHighPrecision: true, + PrevFoundInPlanCache: DefTiDBFoundInPlanCache, + FoundInPlanCache: DefTiDBFoundInPlanCache, + PrevFoundInBinding: DefTiDBFoundInBinding, + FoundInBinding: DefTiDBFoundInBinding, + SelectLimit: math.MaxUint64, + AllowAutoRandExplicitInsert: 
DefTiDBAllowAutoRandExplicitInsert,
+ EnableClusteredIndex: DefTiDBEnableClusteredIndex,
+ EnableParallelApply: DefTiDBEnableParallelApply,
+ ShardAllocateStep: DefTiDBShardAllocateStep,
+ EnableAmendPessimisticTxn: DefTiDBEnableAmendPessimisticTxn,
+ PartitionPruneMode: *atomic2.NewString(DefTiDBPartitionPruneMode),
+ TxnScope: kv.NewDefaultTxnScopeVar(),
+ EnabledRateLimitAction: DefTiDBEnableRateLimitAction,
+ EnableAsyncCommit: DefTiDBEnableAsyncCommit,
+ Enable1PC: DefTiDBEnable1PC,
+ GuaranteeLinearizability: DefTiDBGuaranteeLinearizability,
+ AnalyzeVersion: DefTiDBAnalyzeVersion,
+ EnableIndexMergeJoin: DefTiDBEnableIndexMergeJoin,
+ AllowFallbackToTiKV: make(map[kv.StoreType]struct{}),
+ CTEMaxRecursionDepth: DefCTEMaxRecursionDepth,
+ TMPTableSize: DefTiDBTmpTableMaxSize,
+ MPPStoreFailTTL: DefTiDBMPPStoreFailTTL,
+ Rng: mathutil.NewWithTime(),
+ StatsLoadSyncWait: StatsLoadSyncWait.Load(),
+ EnableLegacyInstanceScope: DefEnableLegacyInstanceScope,
+ RemoveOrderbyInSubquery: DefTiDBRemoveOrderbyInSubquery,
+ EnableSkewDistinctAgg: DefTiDBSkewDistinctAgg,
+ Enable3StageDistinctAgg: DefTiDB3StageDistinctAgg,
+ MaxAllowedPacket: DefMaxAllowedPacket,
+ TiFlashFastScan: DefTiFlashFastScan,
+ EnableTiFlashReadForWriteStmt: DefTiDBEnableTiFlashReadForWriteStmt,
+ ForeignKeyChecks: DefTiDBForeignKeyChecks,
+ HookContext: hctx,
+ EnableReuseCheck: DefTiDBEnableReusechunk,
+ preUseChunkAlloc: DefTiDBUseAlloc,
+ ChunkPool: ReuseChunkPool{Alloc: nil},
+>>>>>>> aeccf77637 (*: optimize mpp probe (#39932))
 }
 vars.KVVars = tikvstore.NewVariables(&vars.Killed)
 vars.Concurrency = Concurrency{
@@ -1695,6 +1804,81 @@ func (s *SessionVars) GetTemporaryTable(tblInfo *model.TableInfo) tableutil.Temp
 return nil
 }
+<<<<<<< HEAD
+=======
+// EncodeSessionStates saves session states into SessionStates.
+func (s *SessionVars) EncodeSessionStates(ctx context.Context, sessionStates *sessionstates.SessionStates) (err error) {
+ // Encode user-defined variables.
+ sessionStates.UserVars = make(map[string]*types.Datum, len(s.userVars.values))
+ sessionStates.UserVarTypes = make(map[string]*ptypes.FieldType, len(s.userVars.types))
+ s.userVars.lock.RLock()
+ defer s.userVars.lock.RUnlock()
+ for name, userVar := range s.userVars.values {
+ sessionStates.UserVars[name] = userVar.Clone()
+ }
+ for name, userVarType := range s.userVars.types {
+ sessionStates.UserVarTypes[name] = userVarType.Clone()
+ }
+
+ // Encode other session contexts.
+ sessionStates.PreparedStmtID = s.preparedStmtID
+ sessionStates.Status = s.Status
+ sessionStates.CurrentDB = s.CurrentDB
+ sessionStates.LastTxnInfo = s.LastTxnInfo
+ if s.LastQueryInfo.StartTS != 0 {
+ sessionStates.LastQueryInfo = &s.LastQueryInfo
+ }
+ if s.LastDDLInfo.SeqNum != 0 {
+ sessionStates.LastDDLInfo = &s.LastDDLInfo
+ }
+ sessionStates.LastFoundRows = s.LastFoundRows
+ sessionStates.SequenceLatestValues = s.SequenceState.GetAllStates()
+ sessionStates.FoundInPlanCache = s.PrevFoundInPlanCache
+ sessionStates.FoundInBinding = s.PrevFoundInBinding
+
+ // Encode StatementContext. We encode it here to avoid a circular dependency.
+ sessionStates.LastAffectedRows = s.StmtCtx.PrevAffectedRows
+ sessionStates.LastInsertID = s.StmtCtx.PrevLastInsertID
+ sessionStates.Warnings = s.StmtCtx.GetWarnings()
+ return
+}
+
+// DecodeSessionStates restores session states from SessionStates.
+func (s *SessionVars) DecodeSessionStates(ctx context.Context, sessionStates *sessionstates.SessionStates) (err error) {
+ // Decode user-defined variables.
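+ // Rebuild the maps from scratch so that no stale user variables survive from
+ // the previous session; values and types are then restored through the
+ // SetUserVarVal/SetUserVarType setters below.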
+ s.userVars.values = make(map[string]types.Datum, len(sessionStates.UserVars)) + for name, userVar := range sessionStates.UserVars { + s.SetUserVarVal(name, *userVar.Clone()) + } + s.userVars.types = make(map[string]*ptypes.FieldType, len(sessionStates.UserVarTypes)) + for name, userVarType := range sessionStates.UserVarTypes { + s.SetUserVarType(name, userVarType.Clone()) + } + + // Decode other session contexts. + s.preparedStmtID = sessionStates.PreparedStmtID + s.Status = sessionStates.Status + s.CurrentDB = sessionStates.CurrentDB + s.LastTxnInfo = sessionStates.LastTxnInfo + if sessionStates.LastQueryInfo != nil { + s.LastQueryInfo = *sessionStates.LastQueryInfo + } + if sessionStates.LastDDLInfo != nil { + s.LastDDLInfo = *sessionStates.LastDDLInfo + } + s.LastFoundRows = sessionStates.LastFoundRows + s.SequenceState.SetAllStates(sessionStates.SequenceLatestValues) + s.FoundInPlanCache = sessionStates.FoundInPlanCache + s.FoundInBinding = sessionStates.FoundInBinding + + // Decode StatementContext. + s.StmtCtx.SetAffectedRows(uint64(sessionStates.LastAffectedRows)) + s.StmtCtx.PrevLastInsertID = sessionStates.LastInsertID + s.StmtCtx.SetWarnings(sessionStates.Warnings) + return +} + +>>>>>>> aeccf77637 (*: optimize mpp probe (#39932)) // TableDelta stands for the changed count for one table or partition. type TableDelta struct { Delta int64 diff --git a/store/copr/BUILD.bazel b/store/copr/BUILD.bazel new file mode 100644 index 0000000000000..f6cbe57efa2d7 --- /dev/null +++ b/store/copr/BUILD.bazel @@ -0,0 +1,91 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "copr", + srcs = [ + "batch_coprocessor.go", + "batch_request_sender.go", + "coprocessor.go", + "coprocessor_cache.go", + "key_ranges.go", + "mpp.go", + "mpp_probe.go", + "region_cache.go", + "store.go", + ], + importpath = "github.com/pingcap/tidb/store/copr", + visibility = ["//visibility:public"], + deps = [ + "//config", + "//domain/infosync", + "//errno", + "//kv", + "//metrics", + "//parser/terror", + "//store/driver/backoff", + "//store/driver/error", + "//store/driver/options", + "//util", + "//util/execdetails", + "//util/logutil", + "//util/mathutil", + "//util/memory", + "//util/paging", + "//util/trxevents", + "@com_github_dgraph_io_ristretto//:ristretto", + "@com_github_gogo_protobuf//proto", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_failpoint//:failpoint", + "@com_github_pingcap_kvproto//pkg/coprocessor", + "@com_github_pingcap_kvproto//pkg/kvrpcpb", + "@com_github_pingcap_kvproto//pkg/metapb", + "@com_github_pingcap_kvproto//pkg/mpp", + "@com_github_pingcap_log//:log", + "@com_github_pingcap_tipb//go-tipb", + "@com_github_stathat_consistent//:consistent", + "@com_github_tikv_client_go_v2//config", + "@com_github_tikv_client_go_v2//error", + "@com_github_tikv_client_go_v2//metrics", + "@com_github_tikv_client_go_v2//tikv", + "@com_github_tikv_client_go_v2//tikvrpc", + "@com_github_tikv_client_go_v2//txnkv/txnlock", + "@com_github_tikv_client_go_v2//txnkv/txnsnapshot", + "@com_github_tikv_client_go_v2//util", + "@org_golang_google_grpc//codes", + "@org_golang_google_grpc//status", + "@org_golang_x_exp//slices", + "@org_uber_go_zap//:zap", + ], +) + +go_test( + name = "copr_test", + timeout = "short", + srcs = [ + "batch_coprocessor_test.go", + "coprocessor_cache_test.go", + "coprocessor_test.go", + "key_ranges_test.go", + "main_test.go", + "mpp_probe_test.go", + ], + embed = [":copr"], + flaky = True, + race = "on", + deps = [ + "//kv", + 
"//store/driver/backoff", + "//testkit/testsetup", + "//util/paging", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_kvproto//pkg/coprocessor", + "@com_github_pingcap_kvproto//pkg/mpp", + "@com_github_stathat_consistent//:consistent", + "@com_github_stretchr_testify//require", + "@com_github_tikv_client_go_v2//config", + "@com_github_tikv_client_go_v2//testutils", + "@com_github_tikv_client_go_v2//tikv", + "@com_github_tikv_client_go_v2//tikvrpc", + "@org_uber_go_goleak//:goleak", + ], +) diff --git a/store/copr/batch_coprocessor.go b/store/copr/batch_coprocessor.go index dae965d912900..75f0dbc8b0a34 100644 --- a/store/copr/batch_coprocessor.go +++ b/store/copr/batch_coprocessor.go @@ -30,8 +30,11 @@ import ( "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/coprocessor" "github.com/pingcap/kvproto/pkg/kvrpcpb" +<<<<<<< HEAD "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/mpp" +======= +>>>>>>> aeccf77637 (*: optimize mpp probe (#39932)) "github.com/pingcap/log" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/store/driver/backoff" @@ -290,12 +293,15 @@ func balanceBatchCopTaskWithContinuity(storeTaskMap map[uint64]*batchCopTask, ca // // The second balance strategy: Not only consider the region count between TiFlash stores, but also try to make the regions' range continuous(stored in TiFlash closely). // If balanceWithContinuity is true, the second balance strategy is enable. +<<<<<<< HEAD func balanceBatchCopTask(ctx context.Context, kvStore *kvStore, originalTasks []*batchCopTask, mppStoreLastFailTime map[string]time.Time, ttl time.Duration, balanceWithContinuity bool, balanceContinuousRegionCount int64) []*batchCopTask { +======= +func balanceBatchCopTask(ctx context.Context, kvStore *kvStore, originalTasks []*batchCopTask, isMPP bool, ttl time.Duration, balanceWithContinuity bool, balanceContinuousRegionCount int64) []*batchCopTask { +>>>>>>> aeccf77637 (*: optimize mpp probe (#39932)) if len(originalTasks) == 0 { log.Info("Batch cop task balancer got an empty task set.") return originalTasks } - isMPP := mppStoreLastFailTime != nil // for mpp, we still need to detect the store availability if len(originalTasks) <= 1 && !isMPP { return originalTasks @@ -326,12 +332,12 @@ func balanceBatchCopTask(ctx context.Context, kvStore *kvStore, originalTasks [] var wg sync.WaitGroup var mu sync.Mutex wg.Add(len(stores)) - cur := time.Now() for i := range stores { go func(idx int) { defer wg.Done() s := stores[idx] +<<<<<<< HEAD var last time.Time var ok bool mu.Lock() @@ -363,6 +369,18 @@ func balanceBatchCopTask(ctx context.Context, kvStore *kvStore, originalTasks [] if cur.Sub(last) < ttl { logutil.BgLogger().Warn("Cannot detect store's availability because the current time has not reached MPPStoreLastFailTime + MPPStoreFailTTL", zap.String("store address", s.GetAddr()), zap.Time("last fail time", last)) +======= + // check if store is failed already. 
+ ok := GlobalMPPFailedStoreProber.IsRecovery(ctx, s.GetAddr(), ttl) + if !ok { + return + } + + tikvClient := kvStore.GetTiKVClient() + ok = detectMPPStore(ctx, tikvClient, s.GetAddr(), DetectTimeoutLimit) + if !ok { + GlobalMPPFailedStoreProber.Add(ctx, s.GetAddr(), tikvClient) +>>>>>>> aeccf77637 (*: optimize mpp probe (#39932)) return } @@ -523,7 +541,95 @@ func balanceBatchCopTask(ctx context.Context, kvStore *kvStore, originalTasks [] return ret } +<<<<<<< HEAD func buildBatchCopTasks(bo *backoff.Backoffer, store *kvStore, ranges *KeyRanges, storeType kv.StoreType, mppStoreLastFailTime map[string]time.Time, ttl time.Duration, balanceWithContinuity bool, balanceContinuousRegionCount int64) ([]*batchCopTask, error) { +======= +func buildBatchCopTasksForNonPartitionedTable(bo *backoff.Backoffer, + store *kvStore, + ranges *KeyRanges, + storeType kv.StoreType, + isMPP bool, + ttl time.Duration, + balanceWithContinuity bool, + balanceContinuousRegionCount int64) ([]*batchCopTask, error) { + if config.GetGlobalConfig().DisaggregatedTiFlash { + return buildBatchCopTasksConsistentHash(bo, store, []*KeyRanges{ranges}, storeType, isMPP, ttl, balanceWithContinuity, balanceContinuousRegionCount) + } + return buildBatchCopTasksCore(bo, store, []*KeyRanges{ranges}, storeType, isMPP, ttl, balanceWithContinuity, balanceContinuousRegionCount) +} + +func buildBatchCopTasksForPartitionedTable(bo *backoff.Backoffer, + store *kvStore, + rangesForEachPhysicalTable []*KeyRanges, + storeType kv.StoreType, + isMPP bool, + ttl time.Duration, + balanceWithContinuity bool, + balanceContinuousRegionCount int64, + partitionIDs []int64) (batchTasks []*batchCopTask, err error) { + if config.GetGlobalConfig().DisaggregatedTiFlash { + batchTasks, err = buildBatchCopTasksConsistentHash(bo, store, rangesForEachPhysicalTable, storeType, isMPP, ttl, balanceWithContinuity, balanceContinuousRegionCount) + } else { + batchTasks, err = buildBatchCopTasksCore(bo, store, rangesForEachPhysicalTable, storeType, isMPP, ttl, balanceWithContinuity, balanceContinuousRegionCount) + } + if err != nil { + return nil, err + } + // generate tableRegions for batchCopTasks + convertRegionInfosToPartitionTableRegions(batchTasks, partitionIDs) + return batchTasks, nil +} + +func buildBatchCopTasksConsistentHash(bo *backoff.Backoffer, store *kvStore, rangesForEachPhysicalTable []*KeyRanges, storeType kv.StoreType, isMPP bool, ttl time.Duration, balanceWithContinuity bool, balanceContinuousRegionCount int64) ([]*batchCopTask, error) { + batchTasks, err := buildBatchCopTasksCore(bo, store, rangesForEachPhysicalTable, storeType, isMPP, ttl, balanceWithContinuity, balanceContinuousRegionCount) + if err != nil { + return nil, err + } + cache := store.GetRegionCache() + stores, err := cache.GetTiFlashComputeStores(bo.TiKVBackoffer()) + if err != nil { + return nil, err + } + if len(stores) == 0 { + return nil, errors.New("No available tiflash_compute node") + } + + hasher := consistent.New() + for _, store := range stores { + hasher.Add(store.GetAddr()) + } + for _, task := range batchTasks { + addr, err := hasher.Get(task.storeAddr) + if err != nil { + return nil, err + } + var store *tikv.Store + for _, s := range stores { + if s.GetAddr() == addr { + store = s + break + } + } + if store == nil { + return nil, errors.New("cannot find tiflash_compute store: " + addr) + } + + task.storeAddr = addr + task.ctx.Store = store + task.ctx.Addr = addr + } + logutil.BgLogger().Info("build batchCop tasks for disaggregated tiflash using ConsistentHash 
done.", zap.Int("len(tasks)", len(batchTasks))) + for _, task := range batchTasks { + logutil.BgLogger().Debug("batchTasks detailed info", zap.String("addr", task.storeAddr), zap.Int("RegionInfo number", len(task.regionInfos))) + } + return batchTasks, nil +} + +// When `partitionIDs != nil`, it means that buildBatchCopTasksCore is constructing a batch cop tasks for PartitionTableScan. +// At this time, `len(rangesForEachPhysicalTable) == len(partitionIDs)` and `rangesForEachPhysicalTable[i]` is for partition `partitionIDs[i]`. +// Otherwise, `rangesForEachPhysicalTable[0]` indicates the range for the single physical table. +func buildBatchCopTasksCore(bo *backoff.Backoffer, store *kvStore, rangesForEachPhysicalTable []*KeyRanges, storeType kv.StoreType, isMPP bool, ttl time.Duration, balanceWithContinuity bool, balanceContinuousRegionCount int64) ([]*batchCopTask, error) { +>>>>>>> aeccf77637 (*: optimize mpp probe (#39932)) cache := store.GetRegionCache() start := time.Now() const cmdType = tikvrpc.CmdBatchCop @@ -598,7 +704,7 @@ func buildBatchCopTasks(bo *backoff.Backoffer, store *kvStore, ranges *KeyRanges logutil.BgLogger().Debug(msg) } balanceStart := time.Now() - batchTasks = balanceBatchCopTask(bo.GetCtx(), store, batchTasks, mppStoreLastFailTime, ttl, balanceWithContinuity, balanceContinuousRegionCount) + batchTasks = balanceBatchCopTask(bo.GetCtx(), store, batchTasks, isMPP, ttl, balanceWithContinuity, balanceContinuousRegionCount) balanceElapsed := time.Since(balanceStart) if log.GetLevel() <= zap.DebugLevel { msg := "After region balance:" @@ -626,8 +732,29 @@ func (c *CopClient) sendBatch(ctx context.Context, req *kv.Request, vars *tikv.V } ctx = context.WithValue(ctx, tikv.TxnStartKey(), req.StartTs) bo := backoff.NewBackofferWithVars(ctx, copBuildTaskMaxBackoff, vars) +<<<<<<< HEAD ranges := NewKeyRanges(req.KeyRanges) tasks, err := buildBatchCopTasks(bo, c.store.kvStore, ranges, req.StoreType, nil, 0, false, 0) +======= + + var tasks []*batchCopTask + var err error + if req.PartitionIDAndRanges != nil { + // For Partition Table Scan + keyRanges := make([]*KeyRanges, 0, len(req.PartitionIDAndRanges)) + partitionIDs := make([]int64, 0, len(req.PartitionIDAndRanges)) + for _, pi := range req.PartitionIDAndRanges { + keyRanges = append(keyRanges, NewKeyRanges(pi.KeyRanges)) + partitionIDs = append(partitionIDs, pi.ID) + } + tasks, err = buildBatchCopTasksForPartitionedTable(bo, c.store.kvStore, keyRanges, req.StoreType, false, 0, false, 0, partitionIDs) + } else { + // TODO: merge the if branch. + ranges := NewKeyRanges(req.KeyRanges.FirstPartitionRange()) + tasks, err = buildBatchCopTasksForNonPartitionedTable(bo, c.store.kvStore, ranges, req.StoreType, false, 0, false, 0) + } + +>>>>>>> aeccf77637 (*: optimize mpp probe (#39932)) if err != nil { return copErrorResponse{err} } @@ -762,6 +889,7 @@ func (b *batchCopIterator) handleTask(ctx context.Context, bo *Backoffer, task * // Merge all ranges and request again. 
func (b *batchCopIterator) retryBatchCopTask(ctx context.Context, bo *backoff.Backoffer, batchTask *batchCopTask) ([]*batchCopTask, error) {
+<<<<<<< HEAD
 var ranges []kv.KeyRange
 for _, ri := range batchTask.regionInfos {
 ri.Ranges.Do(func(ran *kv.KeyRange) {
@@ -769,6 +897,36 @@ func (b *batchCopIterator) retryBatchCopTask(ctx context.Context, bo *backoff.Ba
 })
 }
 return buildBatchCopTasks(bo, b.store, NewKeyRanges(ranges), b.req.StoreType, nil, 0, false, 0)
+=======
+ if batchTask.regionInfos != nil {
+ var ranges []kv.KeyRange
+ for _, ri := range batchTask.regionInfos {
+ ri.Ranges.Do(func(ran *kv.KeyRange) {
+ ranges = append(ranges, *ran)
+ })
+ }
+ ret, err := buildBatchCopTasksForNonPartitionedTable(bo, b.store, NewKeyRanges(ranges), b.req.StoreType, false, 0, false, 0)
+ return ret, err
+ }
+ // Retry Partition Table Scan
+ keyRanges := make([]*KeyRanges, 0, len(batchTask.PartitionTableRegions))
+ pid := make([]int64, 0, len(batchTask.PartitionTableRegions))
+ for _, trs := range batchTask.PartitionTableRegions {
+ pid = append(pid, trs.PhysicalTableId)
+ ranges := make([]kv.KeyRange, 0, len(trs.Regions))
+ for _, ri := range trs.Regions {
+ for _, ran := range ri.Ranges {
+ ranges = append(ranges, kv.KeyRange{
+ StartKey: ran.Start,
+ EndKey: ran.End,
+ })
+ }
+ }
+ keyRanges = append(keyRanges, NewKeyRanges(ranges))
+ }
+ ret, err := buildBatchCopTasksForPartitionedTable(bo, b.store, keyRanges, b.req.StoreType, false, 0, false, 0, pid)
+ return ret, err
+>>>>>>> aeccf77637 (*: optimize mpp probe (#39932))
}

const readTimeoutUltraLong = 3600 * time.Second // For requests that may scan many regions for tiflash.
diff --git a/store/copr/batch_coprocessor_test.go b/store/copr/batch_coprocessor_test.go
index 6a8f503d76aa5..4bd12990b5525 100644
--- a/store/copr/batch_coprocessor_test.go
+++ b/store/copr/batch_coprocessor_test.go
@@ -119,13 +119,13 @@ func TestBalanceBatchCopTaskWithContinuity(t *testing.T) {
 func TestBalanceBatchCopTaskWithEmptyTaskSet(t *testing.T) {
 {
 var nilTaskSet []*batchCopTask
- nilResult := balanceBatchCopTask(nil, nil, nilTaskSet, nil, time.Second, false, 0)
+ nilResult := balanceBatchCopTask(nil, nil, nilTaskSet, false, time.Second, false, 0)
 require.True(t, nilResult == nil)
 }

 {
 emptyTaskSet := make([]*batchCopTask, 0)
- emptyResult := balanceBatchCopTask(nil, nil, emptyTaskSet, nil, time.Second, false, 0)
+ emptyResult := balanceBatchCopTask(nil, nil, emptyTaskSet, false, time.Second, false, 0)
 require.True(t, emptyResult != nil)
 require.True(t, len(emptyResult) == 0)
 }
diff --git a/store/copr/mpp.go b/store/copr/mpp.go
index 39a914c27c5b0..aa6eff058239e 100644
--- a/store/copr/mpp.go
+++ b/store/copr/mpp.go
@@ -61,11 +61,33 @@ func (c *MPPClient) selectAllTiFlashStore() []kv.MPPTaskMeta {
 }

 // ConstructMPPTasks receives a ScheduleRequest, which is actually a collection of kv ranges. We allocate MPPTaskMeta for them and return.
+<<<<<<< HEAD
 func (c *MPPClient) ConstructMPPTasks(ctx context.Context, req *kv.MPPBuildTasksRequest, mppStoreLastFailTime map[string]time.Time, ttl time.Duration) ([]kv.MPPTaskMeta, error) {
 ctx = context.WithValue(ctx, tikv.TxnStartKey(), req.StartTS)
 bo := backoff.NewBackofferWithVars(ctx, copBuildTaskMaxBackoff, nil)
 if req.KeyRanges == nil {
 return c.selectAllTiFlashStore(), nil
+=======
+func (c *MPPClient) ConstructMPPTasks(ctx context.Context, req *kv.MPPBuildTasksRequest, ttl time.Duration) ([]kv.MPPTaskMeta, error) {
+ ctx = context.WithValue(ctx, tikv.TxnStartKey(), req.StartTS)
+ bo := backoff.NewBackofferWithVars(ctx, copBuildTaskMaxBackoff, nil)
+ var tasks []*batchCopTask
+ var err error
+ if req.PartitionIDAndRanges != nil {
+ rangesForEachPartition := make([]*KeyRanges, len(req.PartitionIDAndRanges))
+ partitionIDs := make([]int64, len(req.PartitionIDAndRanges))
+ for i, p := range req.PartitionIDAndRanges {
+ rangesForEachPartition[i] = NewKeyRanges(p.KeyRanges)
+ partitionIDs[i] = p.ID
+ }
+ tasks, err = buildBatchCopTasksForPartitionedTable(bo, c.store, rangesForEachPartition, kv.TiFlash, true, ttl, true, 20, partitionIDs)
+ } else {
+ if req.KeyRanges == nil {
+ return c.selectAllTiFlashStore(), nil
+ }
+ ranges := NewKeyRanges(req.KeyRanges)
+ tasks, err = buildBatchCopTasksForNonPartitionedTable(bo, c.store, ranges, kv.TiFlash, true, ttl, true, 20)
+>>>>>>> aeccf77637 (*: optimize mpp probe (#39932))
 }
 ranges := NewKeyRanges(req.KeyRanges)
 tasks, err := buildBatchCopTasks(bo, c.store, ranges, kv.TiFlash, mppStoreLastFailTime, ttl, true, 20)
diff --git a/store/copr/mpp_probe.go b/store/copr/mpp_probe.go
new file mode 100644
index 0000000000000..0a0eba286648e
--- /dev/null
+++ b/store/copr/mpp_probe.go
@@ -0,0 +1,270 @@
+// Copyright 2022 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package copr
+
+import (
+ "context"
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/pingcap/kvproto/pkg/kvrpcpb"
+ "github.com/pingcap/kvproto/pkg/mpp"
+ "github.com/pingcap/tidb/metrics"
+ "github.com/pingcap/tidb/util/logutil"
+ "github.com/tikv/client-go/v2/tikv"
+ "github.com/tikv/client-go/v2/tikvrpc"
+ "go.uber.org/zap"
+)
+
+// GlobalMPPFailedStoreProber is the global prober for failed MPP stores.
+var GlobalMPPFailedStoreProber *MPPFailedStoreProber
+
+const (
+ // DetectPeriod is the minimum interval between two detections of the same store.
+ DetectPeriod = 3 * time.Second
+ // DetectTimeoutLimit is the timeout of a single detection request.
+ DetectTimeoutLimit = 2 * time.Second
+ // MaxRecoveryTimeLimit is how long a recovered store is kept in the failed list; it must be larger than MPPStoreFailTTL.
+ MaxRecoveryTimeLimit = 15 * time.Minute
+ // MaxObsoletTimeLimit: if a store receives no lookup for this long, it is treated as obsolete and removed.
+ MaxObsoletTimeLimit = time.Hour
+)
+
+// MPPStoreState is the detection state of a single MPP store.
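+// recoveryTime is zero while the store keeps failing detection and records
+// the first successful detection otherwise; lastLookupTime is refreshed on
+// every IsRecovery call, and lastDetectTime throttles detection to one probe
+// per detect period.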
+type MPPStoreState struct {
+ address string // MPPStore TiFlash address
+ tikvClient tikv.Client
+
+ lock struct {
+ sync.Mutex
+
+ recoveryTime time.Time
+ lastLookupTime time.Time
+ lastDetectTime time.Time
+ }
+}
+
+// MPPFailedStoreProber is used to detect failed TiFlash instances.
+type MPPFailedStoreProber struct {
+ failedMPPStores *sync.Map
+ lock *sync.Mutex
+ isStop *atomic.Bool
+ wg *sync.WaitGroup
+ ctx context.Context
+ cancel context.CancelFunc
+
+ detectPeriod time.Duration
+ detectTimeoutLimit time.Duration
+ maxRecoveryTimeLimit time.Duration
+ maxObsoletTimeLimit time.Duration
+}
+
+func (t *MPPStoreState) detect(ctx context.Context, detectPeriod time.Duration, detectTimeoutLimit time.Duration) {
+ if time.Since(t.lock.lastDetectTime) < detectPeriod {
+ return
+ }
+
+ defer func() { t.lock.lastDetectTime = time.Now() }()
+ metrics.TiFlashFailedMPPStoreState.WithLabelValues(t.address).Set(0)
+ ok := detectMPPStore(ctx, t.tikvClient, t.address, detectTimeoutLimit)
+ if !ok {
+ metrics.TiFlashFailedMPPStoreState.WithLabelValues(t.address).Set(1)
+ t.lock.recoveryTime = time.Time{} // if the detection failed, reset the recovery time to zero.
+ return
+ }
+
+ // record the time of the first recovery
+ if t.lock.recoveryTime.IsZero() {
+ t.lock.recoveryTime = time.Now()
+ }
+}
+
+func (t *MPPStoreState) isRecovery(ctx context.Context, recoveryTTL time.Duration) bool {
+ if !t.lock.TryLock() {
+ return false
+ }
+ defer t.lock.Unlock()
+
+ t.lock.lastLookupTime = time.Now()
+ if !t.lock.recoveryTime.IsZero() && time.Since(t.lock.recoveryTime) > recoveryTTL {
+ return true
+ }
+ logutil.Logger(ctx).Debug("Cannot use the store because it has not recovered, "+
+ "or has not stayed recovered for mppStoreFailTTL",
+ zap.String("store address", t.address),
+ zap.Time("recovery time", t.lock.recoveryTime),
+ zap.Duration("MPPStoreFailTTL", recoveryTTL))
+ return false
+}
+
+func (t MPPFailedStoreProber) scan(ctx context.Context) {
+ defer func() {
+ if r := recover(); r != nil {
+ logutil.Logger(ctx).Warn("mpp failed store probe scan error, will restart", zap.Any("recover", r), zap.Stack("stack"))
+ }
+ }()
+
+ do := func(k, v any) {
+ address := fmt.Sprint(k)
+ state, ok := v.(*MPPStoreState)
+ if !ok {
+ logutil.BgLogger().Warn("MPPStoreState type assertion failed, the entry will be cleaned",
+ zap.String("address", address))
+ t.Delete(address)
+ return
+ }
+
+ if !state.lock.TryLock() {
+ return
+ }
+ defer state.lock.Unlock()
+
+ state.detect(ctx, t.detectPeriod, t.detectTimeoutLimit)
+
+ // clean the store that has stayed recovered long enough
+ if !state.lock.recoveryTime.IsZero() && time.Since(state.lock.recoveryTime) > t.maxRecoveryTimeLimit {
+ t.Delete(address)
+ // clean the store that may be obsolete
+ } else if state.lock.recoveryTime.IsZero() && time.Since(state.lock.lastLookupTime) > t.maxObsoletTimeLimit {
+ t.Delete(address)
+ }
+ }
+
+ f := func(k, v any) bool {
+ go do(k, v)
+ return true
+ }
+
+ metrics.TiFlashFailedMPPStoreState.WithLabelValues("probe").Set(-1) // probe heartbeat
+ t.failedMPPStores.Range(f)
+}
+
+// Add adds a store to the failed list when a synchronous probe failed.
+func (t *MPPFailedStoreProber) Add(ctx context.Context, address string, tikvClient tikv.Client) {
+ state := MPPStoreState{
+ address: address,
+ tikvClient: tikvClient,
+ }
+ state.lock.lastLookupTime = time.Now()
+ logutil.Logger(ctx).Debug("add mpp store to failed list", zap.String("address", address))
+ t.failedMPPStores.Store(address, &state)
+}
+
+// IsRecovery checks whether the store has recovered.
+func (t *MPPFailedStoreProber) IsRecovery(ctx context.Context,
address string, recoveryTTL time.Duration) bool {
+ logutil.Logger(ctx).Debug("check failed store recovery",
+ zap.String("address", address), zap.Duration("ttl", recoveryTTL))
+ v, ok := t.failedMPPStores.Load(address)
+ if !ok {
+ // store not in failed map
+ return true
+ }
+
+ state, ok := v.(*MPPStoreState)
+ if !ok {
+ logutil.BgLogger().Warn("MPPStoreState type assertion failed, the entry will be cleaned",
+ zap.String("address", address))
+ t.Delete(address)
+ return false
+ }
+
+ return state.isRecovery(ctx, recoveryTTL)
+}
+
+// Run starts the scan loop.
+// There can be only one background task.
+func (t *MPPFailedStoreProber) Run() {
+ if !t.lock.TryLock() {
+ return
+ }
+ t.wg.Add(1)
+ t.isStop.Swap(false)
+ go func() {
+ defer t.wg.Done()
+ defer t.lock.Unlock()
+ ticker := time.NewTicker(time.Second)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-t.ctx.Done():
+ logutil.BgLogger().Debug("ctx.done")
+ return
+ case <-ticker.C:
+ t.scan(t.ctx)
+ }
+ }
+ }()
+ logutil.BgLogger().Debug("run a background probe process for mpp")
+}
+
+// Stop stops the background goroutine.
+func (t *MPPFailedStoreProber) Stop() {
+ if !t.isStop.CompareAndSwap(false, true) {
+ return
+ }
+ t.cancel()
+ t.wg.Wait()
+ logutil.BgLogger().Debug("stop background task")
+}
+
+// Delete cleans a store from the failed map.
+func (t *MPPFailedStoreProber) Delete(address string) {
+ metrics.TiFlashFailedMPPStoreState.DeleteLabelValues(address)
+ _, ok := t.failedMPPStores.LoadAndDelete(address)
+ if !ok {
+ logutil.BgLogger().Warn("Store to delete is not in the failed map", zap.String("address", address))
+ }
+}
+
+// detectMPPStore checks whether an MPP store is alive.
+func detectMPPStore(ctx context.Context, client tikv.Client, address string, detectTimeoutLimit time.Duration) bool {
+ resp, err := client.SendRequest(ctx, address, &tikvrpc.Request{
+ Type: tikvrpc.CmdMPPAlive,
+ StoreTp: tikvrpc.TiFlash,
+ Req: &mpp.IsAliveRequest{},
+ Context: kvrpcpb.Context{},
+ }, detectTimeoutLimit)
+ if err != nil || !resp.Resp.(*mpp.IsAliveResponse).Available {
+ if err == nil {
+ err = fmt.Errorf("store not ready to serve")
+ }
+ logutil.BgLogger().Warn("Store is not ready",
+ zap.String("store address", address),
+ zap.String("err message", err.Error()))
+ return false
+ }
+ return true
+}
+
+func init() {
+ ctx, cancel := context.WithCancel(context.Background())
+ isStop := atomic.Bool{}
+ isStop.Swap(true)
+ GlobalMPPFailedStoreProber = &MPPFailedStoreProber{
+ failedMPPStores: &sync.Map{},
+ lock: &sync.Mutex{},
+ isStop: &isStop,
+ ctx: ctx,
+ cancel: cancel,
+ wg: &sync.WaitGroup{},
+ detectPeriod: DetectPeriod,
+ detectTimeoutLimit: DetectTimeoutLimit,
+ maxRecoveryTimeLimit: MaxRecoveryTimeLimit,
+ maxObsoletTimeLimit: MaxObsoletTimeLimit,
+ }
+}
diff --git a/store/copr/mpp_probe_test.go b/store/copr/mpp_probe_test.go
new file mode 100644
index 0000000000000..7826c970d3e1e
--- /dev/null
+++ b/store/copr/mpp_probe_test.go
@@ -0,0 +1,177 @@
+// Copyright 2022 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package copr + +import ( + "context" + "testing" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/mpp" + "github.com/stretchr/testify/require" + "github.com/tikv/client-go/v2/tikvrpc" +) + +const ( + testimeout = "timeout" + Error = "error" + Normal = "normal" +) + +type mockDetectClient struct { + errortestype string +} + +func (t *mockDetectClient) CloseAddr(string) error { + return nil +} + +func (t *mockDetectClient) Close() error { + return nil +} + +func (t *mockDetectClient) SendRequest( + ctx context.Context, + addr string, + req *tikvrpc.Request, + timeout time.Duration, +) (*tikvrpc.Response, error) { + if t.errortestype == Error { + return nil, errors.New("store error") + } else if t.errortestype == testimeout { + return &tikvrpc.Response{Resp: &mpp.IsAliveResponse{}}, nil + } + + return &tikvrpc.Response{Resp: &mpp.IsAliveResponse{Available: true}}, nil +} + +type ProbeTest map[string]*mockDetectClient + +func (t ProbeTest) add(ctx context.Context) { + for k, v := range t { + GlobalMPPFailedStoreProber.Add(ctx, k, v) + } +} + +func (t ProbeTest) reSetErrortestype(to string) { + for k, v := range t { + if to == Normal { + v.errortestype = Normal + } else { + v.errortestype = k + } + } +} + +func (t ProbeTest) judge(ctx context.Context, test *testing.T, recoveryTTL time.Duration, need bool) { + for k := range t { + ok := GlobalMPPFailedStoreProber.IsRecovery(ctx, k, recoveryTTL) + require.Equal(test, need, ok) + } +} + +func failedStoreSizeJudge(ctx context.Context, test *testing.T, need int) { + var l int + GlobalMPPFailedStoreProber.scan(ctx) + time.Sleep(time.Second / 10) + GlobalMPPFailedStoreProber.failedMPPStores.Range(func(k, v interface{}) bool { + l++ + return true + }) + require.Equal(test, need, l) +} + +func testFlow(ctx context.Context, probetestest ProbeTest, test *testing.T, flow []string) { + probetestest.add(ctx) + for _, to := range flow { + probetestest.reSetErrortestype(to) + + GlobalMPPFailedStoreProber.scan(ctx) + time.Sleep(time.Second / 10) //wait detect goroutine finish + + var need bool + if to == Normal { + need = true + } + probetestest.judge(ctx, test, 0, need) + probetestest.judge(ctx, test, time.Minute, false) + } + + lastTo := flow[len(flow)-1] + cleanRecover := func(need int) { + GlobalMPPFailedStoreProber.maxRecoveryTimeLimit = 0 - time.Second + failedStoreSizeJudge(ctx, test, need) + GlobalMPPFailedStoreProber.maxRecoveryTimeLimit = MaxRecoveryTimeLimit + } + + cleanObsolet := func(need int) { + GlobalMPPFailedStoreProber.maxObsoletTimeLimit = 0 - time.Second + failedStoreSizeJudge(ctx, test, need) + GlobalMPPFailedStoreProber.maxObsoletTimeLimit = MaxObsoletTimeLimit + } + + if lastTo == Error { + cleanRecover(2) + cleanObsolet(0) + } else if lastTo == Normal { + cleanObsolet(2) + cleanRecover(0) + } +} + +func TestMPPFailedStoreProbe(t *testing.T) { + ctx := context.Background() + + notExistAddress := "not exist address" + + GlobalMPPFailedStoreProber.detectPeriod = 0 - time.Second + + // check not exist address + ok := GlobalMPPFailedStoreProber.IsRecovery(ctx, notExistAddress, 0) + require.True(t, ok) + + GlobalMPPFailedStoreProber.scan(ctx) + + probetestest := map[string]*mockDetectClient{ + testimeout: {errortestype: testimeout}, + Error: {errortestype: Error}, + } + + testFlowFinallyRecover := []string{Error, Normal, Error, Error, Normal} + testFlow(ctx, probetestest, t, testFlowFinallyRecover) + testFlowFinallyDesert := []string{Error, Normal, Normal, Error, Error} + testFlow(ctx, probetestest, t, 
testFlowFinallyDesert) +} + +func TestMPPFailedStoreProbeGoroutineTask(t *testing.T) { + // Confirm that multiple tasks are not allowed + GlobalMPPFailedStoreProber.lock.Lock() + GlobalMPPFailedStoreProber.Run() + GlobalMPPFailedStoreProber.lock.Unlock() + + GlobalMPPFailedStoreProber.Run() + GlobalMPPFailedStoreProber.Stop() +} + +func TestMPPFailedStoreAssertFailed(t *testing.T) { + ctx := context.Background() + + GlobalMPPFailedStoreProber.failedMPPStores.Store("errorinfo", nil) + GlobalMPPFailedStoreProber.scan(ctx) + + GlobalMPPFailedStoreProber.failedMPPStores.Store("errorinfo", nil) + GlobalMPPFailedStoreProber.IsRecovery(ctx, "errorinfo", 0) +} diff --git a/store/copr/store.go b/store/copr/store.go index 1783ee294f8e1..362af4d4055e9 100644 --- a/store/copr/store.go +++ b/store/copr/store.go @@ -78,6 +78,7 @@ func NewStore(s *tikv.KVStore, coprCacheConfig *config.CoprocessorCache) (*Store if err != nil { return nil, errors.Trace(err) } + /* #nosec G404 */ return &Store{ kvStore: &kvStore{store: s}, diff --git a/tests/realtikvtest/sessiontest/session_fail_test.go b/tests/realtikvtest/sessiontest/session_fail_test.go new file mode 100644 index 0000000000000..a3df51be821c0 --- /dev/null +++ b/tests/realtikvtest/sessiontest/session_fail_test.go @@ -0,0 +1,205 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package sessiontest
+
+import (
+ "context"
+ "testing"
+
+ "github.com/pingcap/failpoint"
+ "github.com/pingcap/tidb/config"
+ "github.com/pingcap/tidb/session"
+ "github.com/pingcap/tidb/sessionctx/variable"
+ "github.com/pingcap/tidb/testkit"
+ "github.com/pingcap/tidb/tests/realtikvtest"
+ "github.com/stretchr/testify/require"
+)
+
+func TestFailStatementCommitInRetry(t *testing.T) {
+ store := realtikvtest.CreateMockStoreAndSetup(t)
+
+ tk := testkit.NewTestKit(t, store)
+ tk.MustExec("use test")
+ tk.Session().GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn
+ tk.MustExec("create table t (id int)")
+
+ tk.MustExec("begin")
+ tk.MustExec("insert into t values (1)")
+ tk.MustExec("insert into t values (2),(3),(4),(5)")
+ tk.MustExec("insert into t values (6)")
+
+ require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/session/mockCommitError8942", `return(true)`))
+ _, err := tk.Exec("commit")
+ require.Error(t, err)
+ require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/session/mockCommitError8942"))
+
+ tk.MustExec("insert into t values (6)")
+ tk.MustQuery(`select * from t`).Check(testkit.Rows("6"))
+}
+
+func TestGetTSFailDirtyState(t *testing.T) {
+ store := realtikvtest.CreateMockStoreAndSetup(t)
+
+ tk := testkit.NewTestKit(t, store)
+ tk.MustExec("use test")
+ tk.Session().GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn
+ tk.MustExec("create table t (id int)")
+
+ require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/session/mockGetTSFail", "return"))
+ ctx := failpoint.WithHook(context.Background(), func(ctx context.Context, fpname string) bool {
+ return fpname == "github.com/pingcap/tidb/session/mockGetTSFail"
+ })
+ _, err := tk.Session().Execute(ctx, "select * from t")
+ if config.GetGlobalConfig().Store == "unistore" {
+ require.Error(t, err)
+ } else {
+ require.NoError(t, err)
+ }
+
+ // Fix a bug: when activating the txn failed, TxnState.fail was set to an error,
+ // and the following writes were affected by this fail flag.
+ tk.MustExec("insert into t values (1)")
+ tk.MustQuery(`select * from t`).Check(testkit.Rows("1"))
+ require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/session/mockGetTSFail"))
+}
+
+func TestGetTSFailDirtyStateInretry(t *testing.T) {
+ defer func() {
+ require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/session/mockCommitError"))
+ require.NoError(t, failpoint.Disable("tikvclient/mockGetTSErrorInRetry"))
+ }()
+
+ store := realtikvtest.CreateMockStoreAndSetup(t)
+
+ tk := testkit.NewTestKit(t, store)
+ tk.MustExec("use test")
+ tk.Session().GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn
+ tk.MustExec("create table t (id int)")
+
+ require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/session/mockCommitError", `return(true)`))
+ // This test mocks a PD timeout error, and then recovers.
+ // Just make mockGetTSErrorInRetry return true once, and then return false.
+ require.NoError(t, failpoint.Enable("tikvclient/mockGetTSErrorInRetry",
+ `1*return(true)->return(false)`))
+ tk.MustExec("insert into t values (2)")
+ tk.MustQuery(`select * from t`).Check(testkit.Rows("2"))
+}
+
+func TestKillFlagInBackoff(t *testing.T) {
+ // This test checks the `killed` flag is passed down to the backoffer through
+ // session.KVVars.
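+ // The failpoint below injects a single timeout; with `Killed` set, the
+ // backoffer should observe the flag while backing off and abort the query
+ // instead of retrying it to completion.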
+ store := realtikvtest.CreateMockStoreAndSetup(t)
+
+ tk := testkit.NewTestKit(t, store)
+ tk.MustExec("use test")
+ tk.Session().GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn
+ tk.MustExec("create table kill_backoff (id int)")
+ // Inject a timeout once. If `Killed` is not successfully passed down, the query will retry and complete.
+ require.NoError(t, failpoint.Enable("tikvclient/tikvStoreSendReqResult", `return("timeout")->return("")`))
+ defer failpoint.Disable("tikvclient/tikvStoreSendReqResult")
+ // Set the kill flag and check that it is passed to the backoffer.
+ tk.Session().GetSessionVars().Killed = 1
+ rs, err := tk.Exec("select * from kill_backoff")
+ require.NoError(t, err)
+ _, err = session.ResultSetToStringSlice(context.TODO(), tk.Session(), rs)
+ // `interrupted` is returned when `Killed` is set.
+ require.Regexp(t, ".*Query execution was interrupted.*", err.Error())
+ rs.Close()
+}
+
+func TestClusterTableSendError(t *testing.T) {
+ store := realtikvtest.CreateMockStoreAndSetup(t)
+
+ tk := testkit.NewTestKit(t, store)
+ tk.MustExec("use test")
+ tk.Session().GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn
+ require.NoError(t, failpoint.Enable("tikvclient/tikvStoreSendReqResult", `return("requestTiDBStoreError")`))
+ defer func() { require.NoError(t, failpoint.Disable("tikvclient/tikvStoreSendReqResult")) }()
+ tk.MustQuery("select * from information_schema.cluster_slow_query")
+ require.Equal(t, tk.Session().GetSessionVars().StmtCtx.WarningCount(), uint16(1))
+ require.Regexp(t, ".*TiDB server timeout, address is.*", tk.Session().GetSessionVars().StmtCtx.GetWarnings()[0].Err.Error())
+}
+
+func TestAutoCommitNeedNotLinearizability(t *testing.T) {
+ store := realtikvtest.CreateMockStoreAndSetup(t)
+
+ tk := testkit.NewTestKit(t, store)
+ tk.MustExec("use test")
+ tk.Session().GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn
+ tk.MustExec("drop table if exists t1;")
+ defer tk.MustExec("drop table if exists t1")
+ tk.MustExec(`create table t1 (c int)`)
+
+ require.NoError(t, failpoint.Enable("tikvclient/getMinCommitTSFromTSO", `panic`))
+ defer func() { require.NoError(t, failpoint.Disable("tikvclient/getMinCommitTSFromTSO")) }()
+
+ require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("tidb_enable_async_commit", "1"))
+ require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("tidb_guarantee_linearizability", "1"))
+
+ // Auto-commit transactions don't need to get minCommitTS from TSO
+ tk.MustExec("INSERT INTO t1 VALUES (1)")
+
+ tk.MustExec("BEGIN")
+ tk.MustExec("INSERT INTO t1 VALUES (2)")
+ // An explicit transaction needs to get minCommitTS from TSO
+ func() {
+ defer func() {
+ err := recover()
+ require.NotNil(t, err)
+ }()
+ tk.MustExec("COMMIT")
+ }()
+
+ tk.MustExec("set autocommit = 0")
+ tk.MustExec("INSERT INTO t1 VALUES (3)")
+ func() {
+ defer func() {
+ err := recover()
+ require.NotNil(t, err)
+ }()
+ tk.MustExec("COMMIT")
+ }()
+
+ // Same for 1PC
+ tk.MustExec("set autocommit = 1")
+ require.NoError(t, tk.Session().GetSessionVars().SetSystemVar("tidb_enable_1pc", "1"))
+ tk.MustExec("INSERT INTO t1 VALUES (4)")
+
+ tk.MustExec("BEGIN")
+ tk.MustExec("INSERT INTO t1 VALUES (5)")
+ func() {
+ defer func() {
+ err := recover()
+ require.NotNil(t, err)
+ }()
+ tk.MustExec("COMMIT")
+ }()
+
+ tk.MustExec("set autocommit = 0")
+ tk.MustExec("INSERT INTO t1 VALUES (6)")
+ func() {
+ defer func() {
+ err := recover()
+ require.NotNil(t, err)
+ }()
+ tk.MustExec("COMMIT")
+ }()
+}
+
+func
TestKill(t *testing.T) { + store := realtikvtest.CreateMockStoreAndSetup(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("kill connection_id();") +} diff --git a/tidb-server/BUILD.bazel b/tidb-server/BUILD.bazel new file mode 100644 index 0000000000000..361a929351642 --- /dev/null +++ b/tidb-server/BUILD.bazel @@ -0,0 +1,107 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library", "go_test") + +go_library( + name = "tidb-server_lib", + srcs = ["main.go"], + importpath = "github.com/pingcap/tidb/tidb-server", + visibility = ["//visibility:private"], + deps = [ + "//bindinfo", + "//config", + "//ddl", + "//domain", + "//domain/infosync", + "//executor", + "//extension", + "//kv", + "//metrics", + "//parser/mysql", + "//parser/terror", + "//parser/types", + "//planner/core", + "//plugin", + "//privilege/privileges", + "//resourcemanager:resourcemanage", + "//server", + "//session", + "//session/txninfo", + "//sessionctx/binloginfo", + "//sessionctx/variable", + "//statistics", + "//store", + "//store/copr", + "//store/driver", + "//store/mockstore", + "//store/mockstore/unistore/metrics", + "//tidb-binlog/pump_client", + "//util", + "//util/chunk", + "//util/cpuprofile", + "//util/deadlockhistory", + "//util/disk", + "//util/domainutil", + "//util/kvcache", + "//util/logutil", + "//util/memory", + "//util/printer", + "//util/sem", + "//util/signal", + "//util/sys/linux", + "//util/sys/storage", + "//util/systimemon", + "//util/topsql", + "//util/versioninfo", + "@com_github_coreos_go_semver//semver", + "@com_github_opentracing_opentracing_go//:opentracing-go", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_failpoint//:failpoint", + "@com_github_pingcap_log//:log", + "@com_github_prometheus_client_golang//prometheus", + "@com_github_prometheus_client_golang//prometheus/push", + "@com_github_tikv_client_go_v2//tikv", + "@com_github_tikv_client_go_v2//txnkv/transaction", + "@com_github_tikv_pd_client//:client", + "@org_uber_go_automaxprocs//maxprocs", + "@org_uber_go_zap//:zap", + ], +) + +go_binary( + name = "tidb-server", + embed = [":tidb-server_lib"], + visibility = ["//visibility:public"], + x_defs = { + "github.com/pingcap/tidb/parser/mysql.TiDBReleaseVersion": "{STABLE_TiDB_RELEASE_VERSION}", + "github.com/pingcap/tidb/util/versioninfo.TiDBBuildTS": "{STABLE_TiDB_BUILD_UTCTIME}", + "github.com/pingcap/tidb/util/versioninfo.TiDBGitHash": "{STABLE_TIDB_GIT_HASH}", + "github.com/pingcap/tidb/util/versioninfo.TiDBGitBranch": "{STABLE_TIDB_GIT_BRANCH}", + "github.com/pingcap/tidb/util/versioninfo.TiDBEdition": "{STABLE_TIDB_EDITION}", + }, +) + +go_binary( + name = "tidb-server-check", + embed = [":tidb-server_lib"], + gc_linkopts = [ + "-X", + "github.com/pingcap/tidb/config.checkBeforeDropLDFlag=1", + ], + visibility = ["//visibility:public"], +) + +go_test( + name = "tidb-server_test", + timeout = "short", + srcs = ["main_test.go"], + embed = [":tidb-server_lib"], + flaky = True, + deps = [ + "//config", + "//parser/mysql", + "//sessionctx/variable", + "//testkit/testsetup", + "@com_github_stretchr_testify//require", + "@io_opencensus_go//stats/view", + "@org_uber_go_goleak//:goleak", + ], +) diff --git a/tidb-server/main.go b/tidb-server/main.go index ba35f263e960b..30bbdc609e4b4 100644 --- a/tidb-server/main.go +++ b/tidb-server/main.go @@ -50,6 +50,7 @@ import ( "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/statistics" kvstore "github.com/pingcap/tidb/store" + "github.com/pingcap/tidb/store/copr" 
"github.com/pingcap/tidb/store/driver" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/util" @@ -292,6 +293,12 @@ func createStoreAndDomain() (kv.Storage, *domain.Domain) { var err error storage, err := kvstore.New(fullPath) terror.MustNil(err) +<<<<<<< HEAD +======= + copr.GlobalMPPFailedStoreProber.Run() + err = infosync.CheckTiKVVersion(storage, *semver.New(versioninfo.TiKVMinVersion)) + terror.MustNil(err) +>>>>>>> aeccf77637 (*: optimize mpp probe (#39932)) // Bootstrap a session to load information schema. dom, err := session.BootstrapSession(storage) terror.MustNil(err) @@ -706,6 +713,7 @@ func setupTracing() { func closeDomainAndStorage(storage kv.Storage, dom *domain.Domain) { tikv.StoreShuttingDown(1) dom.Close() + copr.GlobalMPPFailedStoreProber.Stop() err := storage.Close() terror.Log(errors.Trace(err)) }