From c30d34f07f925a7914600bb60373581f93a18dae Mon Sep 17 00:00:00 2001 From: dongjunduo Date: Tue, 14 Dec 2021 05:08:36 -0600 Subject: [PATCH 1/8] planner: Introduce a new global variable to control the historical statistics feature (#30646) --- executor/set_test.go | 6 ++++++ sessionctx/variable/sysvar.go | 6 +++++- sessionctx/variable/tidb_vars.go | 2 ++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/executor/set_test.go b/executor/set_test.go index 77ab1b1d26c1a..9be1f1794ce1c 100644 --- a/executor/set_test.go +++ b/executor/set_test.go @@ -587,6 +587,12 @@ func (s *testSerialSuite1) TestSetVar(c *C) { tk.MustExec("set global tidb_enable_tso_follower_proxy = 0") tk.MustQuery("select @@tidb_enable_tso_follower_proxy").Check(testkit.Rows("0")) c.Assert(tk.ExecToErr("set tidb_enable_tso_follower_proxy = 1"), NotNil) + + tk.MustQuery("select @@tidb_enable_historical_stats").Check(testkit.Rows("0")) + tk.MustExec("set global tidb_enable_historical_stats = 1") + tk.MustQuery("select @@tidb_enable_historical_stats").Check(testkit.Rows("1")) + tk.MustExec("set global tidb_enable_historical_stats = 0") + tk.MustQuery("select @@tidb_enable_historical_stats").Check(testkit.Rows("0")) } func (s *testSuite5) TestTruncateIncorrectIntSessionVar(c *C) { diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index ebd1eed637ddd..60543c00d334e 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -1205,7 +1205,11 @@ var defaultSysVars = []*SysVar{ {Scope: ScopeSession, Name: PluginDir, Value: "/data/deploy/plugin", GetSession: func(s *SessionVars) (string, error) { return config.GetGlobalConfig().Plugin.Dir, nil }}, - + {Scope: ScopeGlobal, Name: TiDBEnableHistoricalStats, Value: Off, Type: TypeBool, GetGlobal: func(s *SessionVars) (string, error) { + return getTiDBTableValue(s, "tidb_enable_historical_stats", Off) + }, SetGlobal: func(s *SessionVars, val string) error { + return setTiDBTableValue(s, "tidb_enable_historical_stats", val, "Current historical statistics enable status") + }}, /* tikv gc metrics */ {Scope: ScopeGlobal, Name: TiDBGCEnable, Value: On, Type: TypeBool, GetGlobal: func(s *SessionVars) (string, error) { return getTiDBTableValue(s, "tikv_gc_enable", On) diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index 518b88e45de80..ee01348a76441 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -620,6 +620,8 @@ const ( TiDBGCScanLockMode = "tidb_gc_scan_lock_mode" // TiDBEnableEnhancedSecurity restricts SUPER users from certain operations. 
TiDBEnableEnhancedSecurity = "tidb_enable_enhanced_security" + // TiDBEnableHistoricalStats enables the historical statistics feature (default off) + TiDBEnableHistoricalStats = "tidb_enable_historical_stats" ) // TiDB intentional limits From e9b1fb8ab57afa2e5ca71a09c1ba004404b75a2b Mon Sep 17 00:00:00 2001 From: Zhenchi Date: Tue, 14 Dec 2021 19:22:34 +0800 Subject: [PATCH 2/8] topsql: introduce datasink interface (#30662) --- server/tidb_test.go | 2 +- util/topsql/reporter/datasink.go | 36 ++++ util/topsql/reporter/reporter.go | 40 ++-- util/topsql/reporter/reporter_test.go | 2 +- .../reporter/{client.go => single_target.go} | 177 +++++++++++++----- util/topsql/topsql.go | 4 +- util/topsql/topsql_test.go | 2 +- 7 files changed, 188 insertions(+), 75 deletions(-) create mode 100644 util/topsql/reporter/datasink.go rename util/topsql/reporter/{client.go => single_target.go} (55%) diff --git a/server/tidb_test.go b/server/tidb_test.go index 55c830eb23325..5c0c2e9e189e1 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -1536,7 +1536,7 @@ func TestTopSQLAgent(t *testing.T) { dbt.MustExec("set @@global.tidb_top_sql_report_interval_seconds=2;") dbt.MustExec("set @@global.tidb_top_sql_max_statement_count=5;") - r := reporter.NewRemoteTopSQLReporter(reporter.NewGRPCReportClient(plancodec.DecodeNormalizedPlan)) + r := reporter.NewRemoteTopSQLReporter(reporter.NewSingleTargetDataSink(plancodec.DecodeNormalizedPlan)) tracecpu.GlobalSQLCPUProfiler.SetCollector(&collectorWrapper{r}) // TODO: change to ensure that the right sql statements are reported, not just counts diff --git a/util/topsql/reporter/datasink.go b/util/topsql/reporter/datasink.go new file mode 100644 index 0000000000000..c4206c71dc440 --- /dev/null +++ b/util/topsql/reporter/datasink.go @@ -0,0 +1,36 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reporter + +import "time" + +// DataSink collects and sends data to a target. +type DataSink interface { + // TrySend pushes a report data into the sink, which will later be sent to a target by the sink. A deadline can be + // specified to control how late it should be sent. If the sink is kept full and cannot schedule a send within + // the specified deadline, or the sink is closed, an error will be returned. + TrySend(data ReportData, deadline time.Time) error + + // IsPaused indicates that the DataSink is not expecting to receive records for now + // and may resume in the future. + IsPaused() bool + + // IsDown indicates that the DataSink has been down and can be cleared. + // Note that: once a DataSink is down, it cannot go back to be up. 
+ IsDown() bool + + // Close cleans up resources owned by this DataSink + Close() +} diff --git a/util/topsql/reporter/reporter.go b/util/topsql/reporter/reporter.go index 113cb5ce29bc4..39e3be5eae8e4 100644 --- a/util/topsql/reporter/reporter.go +++ b/util/topsql/reporter/reporter.go @@ -23,7 +23,6 @@ import ( "time" "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util" @@ -118,9 +117,9 @@ type planBinaryDecodeFunc func(string) (string, error) // RemoteTopSQLReporter implements a TopSQL reporter that sends data to a remote agent // This should be called periodically to collect TopSQL resource usage metrics type RemoteTopSQLReporter struct { - ctx context.Context - cancel context.CancelFunc - client ReportClient + ctx context.Context + cancel context.CancelFunc + dataSink DataSink // normalizedSQLMap is an map, whose keys are SQL digest strings and values are SQLMeta. normalizedSQLMap atomic.Value // sync.Map @@ -145,12 +144,12 @@ type SQLMeta struct { // // planBinaryDecoder is a decoding function which will be called asynchronously to decode the plan binary to string // MaxStatementsNum is the maximum SQL and plan number, which will restrict the memory usage of the internal LFU cache -func NewRemoteTopSQLReporter(client ReportClient) *RemoteTopSQLReporter { +func NewRemoteTopSQLReporter(dataSink DataSink) *RemoteTopSQLReporter { ctx, cancel := context.WithCancel(context.Background()) tsr := &RemoteTopSQLReporter{ ctx: ctx, cancel: cancel, - client: client, + dataSink: dataSink, collectCPUDataChan: make(chan cpuData, 1), reportCollectedDataChan: make(chan collectedData, 1), } @@ -238,7 +237,7 @@ func (tsr *RemoteTopSQLReporter) Collect(timestamp uint64, records []tracecpu.SQ // Close uses to close and release the reporter resource. func (tsr *RemoteTopSQLReporter) Close() { tsr.cancel() - tsr.client.Close() + tsr.dataSink.Close() } func addEvictedCPUTime(collectTarget map[string]*dataPoints, timestamp uint64, totalCPUTimeMs uint32) { @@ -450,15 +449,15 @@ type collectedData struct { normalizedPlanMap *sync.Map } -// reportData contains data that reporter sends to the agent -type reportData struct { +// ReportData contains data that reporter sends to the agent +type ReportData struct { // collectedData contains the topN collected records and the `others` record which aggregation all records that is out of Top N. collectedData []*dataPoints normalizedSQLMap *sync.Map normalizedPlanMap *sync.Map } -func (d *reportData) hasData() bool { +func (d *ReportData) hasData() bool { if len(d.collectedData) > 0 { return true } @@ -496,9 +495,9 @@ func (tsr *RemoteTopSQLReporter) reportWorker() { } } -// getReportData gets reportData from the collectedData. +// getReportData gets ReportData from the collectedData. // This function will calculate the topN collected records and the `others` record which aggregation all records that is out of Top N. -func (tsr *RemoteTopSQLReporter) getReportData(collected collectedData) reportData { +func (tsr *RemoteTopSQLReporter) getReportData(collected collectedData) ReportData { // Fetch TopN dataPoints. 
others := collected.records[keyOthers] delete(collected.records, keyOthers) @@ -524,21 +523,20 @@ func (tsr *RemoteTopSQLReporter) getReportData(collected collectedData) reportDa records = append(records, others) } - return reportData{ + return ReportData{ collectedData: records, normalizedSQLMap: collected.normalizedSQLMap, normalizedPlanMap: collected.normalizedPlanMap, } } -func (tsr *RemoteTopSQLReporter) doReport(data reportData) { +func (tsr *RemoteTopSQLReporter) doReport(data ReportData) { defer util.Recover("top-sql", "doReport", nil, false) if !data.hasData() { return } - agentAddr := config.GetGlobalConfig().TopSQL.ReceiverAddress timeout := reportTimeout failpoint.Inject("resetTimeoutForTest", func(val failpoint.Value) { if val.(bool) { @@ -548,14 +546,8 @@ func (tsr *RemoteTopSQLReporter) doReport(data reportData) { } } }) - ctx, cancel := context.WithTimeout(tsr.ctx, timeout) - start := time.Now() - err := tsr.client.Send(ctx, agentAddr, data) - if err != nil { - logutil.BgLogger().Warn("[top-sql] client failed to send data", zap.Error(err)) - reportAllDurationFailedHistogram.Observe(time.Since(start).Seconds()) - } else { - reportAllDurationSuccHistogram.Observe(time.Since(start).Seconds()) + deadline := time.Now().Add(timeout) + if err := tsr.dataSink.TrySend(data, deadline); err != nil { + logutil.BgLogger().Warn("[top-sql] failed to send data to datasink", zap.Error(err)) } - cancel() } diff --git a/util/topsql/reporter/reporter_test.go b/util/topsql/reporter/reporter_test.go index ab146930b85a7..33d048fcf1d86 100644 --- a/util/topsql/reporter/reporter_test.go +++ b/util/topsql/reporter/reporter_test.go @@ -72,7 +72,7 @@ func setupRemoteTopSQLReporter(maxStatementsNum, interval int, addr string) *Rem conf.TopSQL.ReceiverAddress = addr }) - rc := NewGRPCReportClient(mockPlanBinaryDecoderFunc) + rc := NewSingleTargetDataSink(mockPlanBinaryDecoderFunc) ts := NewRemoteTopSQLReporter(rc) return ts } diff --git a/util/topsql/reporter/client.go b/util/topsql/reporter/single_target.go similarity index 55% rename from util/topsql/reporter/client.go rename to util/topsql/reporter/single_target.go index 994189250e52b..5ed12c11853d8 100644 --- a/util/topsql/reporter/client.go +++ b/util/topsql/reporter/single_target.go @@ -16,10 +16,12 @@ package reporter import ( "context" + "errors" "math" "sync" "time" + "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tipb/go-tipb" "go.uber.org/zap" @@ -27,36 +29,126 @@ import ( "google.golang.org/grpc/backoff" ) -// ReportClient send data to the target server. -type ReportClient interface { - Send(ctx context.Context, addr string, data reportData) error - Close() -} +// SingleTargetDataSink reports data to grpc servers. +type SingleTargetDataSink struct { + ctx context.Context + cancel context.CancelFunc -// GRPCReportClient reports data to grpc servers. 
-type GRPCReportClient struct { curRPCAddr string conn *grpc.ClientConn + sendTaskCh chan sendTask + // calling decodePlan this can take a while, so should not block critical paths decodePlan planBinaryDecodeFunc } -// NewGRPCReportClient returns a new GRPCReportClient -func NewGRPCReportClient(decodePlan planBinaryDecodeFunc) *GRPCReportClient { - return &GRPCReportClient{ +// NewSingleTargetDataSink returns a new SingleTargetDataSink +func NewSingleTargetDataSink(decodePlan planBinaryDecodeFunc) *SingleTargetDataSink { + ctx, cancel := context.WithCancel(context.Background()) + dataSink := &SingleTargetDataSink{ + ctx: ctx, + cancel: cancel, + + curRPCAddr: "", + conn: nil, + sendTaskCh: make(chan sendTask, 1), + decodePlan: decodePlan, } + go dataSink.recoverRun() + return dataSink +} + +// recoverRun will run until SingleTargetDataSink is closed. +func (ds *SingleTargetDataSink) recoverRun() { + for ds.run() { + } +} + +func (ds *SingleTargetDataSink) run() (rerun bool) { + defer func() { + r := recover() + if r != nil { + logutil.BgLogger().Error("panic in SingleTargetDataSink, rerun", + zap.Reflect("r", r), + zap.Stack("stack trace")) + rerun = true + } + }() + + for { + var task sendTask + select { + case <-ds.ctx.Done(): + return false + case task = <-ds.sendTaskCh: + } + + targetRPCAddr := config.GetGlobalConfig().TopSQL.ReceiverAddress + if targetRPCAddr == "" { + continue + } + + ctx, cancel := context.WithDeadline(context.Background(), task.deadline) + start := time.Now() + err := ds.doSend(ctx, targetRPCAddr, task.data) + cancel() + if err != nil { + logutil.BgLogger().Warn("[top-sql] single target data sink failed to send data to receiver", zap.Error(err)) + reportAllDurationFailedHistogram.Observe(time.Since(start).Seconds()) + } else { + reportAllDurationSuccHistogram.Observe(time.Since(start).Seconds()) + } + } } -var _ ReportClient = &GRPCReportClient{} +var _ DataSink = &SingleTargetDataSink{} -// Send implements the ReportClient interface. -// Currently the implementation will establish a new connection every time, which is suitable for a per-minute sending period -func (r *GRPCReportClient) Send(ctx context.Context, targetRPCAddr string, data reportData) error { - if targetRPCAddr == "" { +// TrySend implements the DataSink interface. +// Currently the implementation will establish a new connection every time, +// which is suitable for a per-minute sending period +func (ds *SingleTargetDataSink) TrySend(data ReportData, deadline time.Time) error { + select { + case ds.sendTaskCh <- sendTask{data: data, deadline: deadline}: return nil + case <-ds.ctx.Done(): + return ds.ctx.Err() + default: + ignoreReportChannelFullCounter.Inc() + return errors.New("the channel of single target dataSink is full") + } +} + +// IsPaused implements the DataSink interface. +func (ds *SingleTargetDataSink) IsPaused() bool { + return len(config.GetGlobalConfig().TopSQL.ReceiverAddress) == 0 +} + +// IsDown implements the DataSink interface. +func (ds *SingleTargetDataSink) IsDown() bool { + select { + case <-ds.ctx.Done(): + return true + default: + return false + } +} + +// Close uses to close grpc connection. 
+func (ds *SingleTargetDataSink) Close() { + ds.cancel() + if ds.conn == nil { + return + } + err := ds.conn.Close() + if err != nil { + logutil.BgLogger().Warn("[top-sql] single target dataSink close connection failed", zap.Error(err)) } - err := r.tryEstablishConnection(ctx, targetRPCAddr) + ds.conn = nil +} + +func (ds *SingleTargetDataSink) doSend(ctx context.Context, addr string, data ReportData) error { + err := ds.tryEstablishConnection(ctx, addr) if err != nil { return err } @@ -67,15 +159,15 @@ func (r *GRPCReportClient) Send(ctx context.Context, targetRPCAddr string, data go func() { defer wg.Done() - errCh <- r.sendBatchSQLMeta(ctx, data.normalizedSQLMap) + errCh <- ds.sendBatchSQLMeta(ctx, data.normalizedSQLMap) }() go func() { defer wg.Done() - errCh <- r.sendBatchPlanMeta(ctx, data.normalizedPlanMap) + errCh <- ds.sendBatchPlanMeta(ctx, data.normalizedPlanMap) }() go func() { defer wg.Done() - errCh <- r.sendBatchCPUTimeRecord(ctx, data.collectedData) + errCh <- ds.sendBatchCPUTimeRecord(ctx, data.collectedData) }() wg.Wait() close(errCh) @@ -87,25 +179,13 @@ func (r *GRPCReportClient) Send(ctx context.Context, targetRPCAddr string, data return nil } -// Close uses to close grpc connection. -func (r *GRPCReportClient) Close() { - if r.conn == nil { - return - } - err := r.conn.Close() - if err != nil { - logutil.BgLogger().Warn("[top-sql] grpc client close connection failed", zap.Error(err)) - } - r.conn = nil -} - // sendBatchCPUTimeRecord sends a batch of TopSQL records by stream. -func (r *GRPCReportClient) sendBatchCPUTimeRecord(ctx context.Context, records []*dataPoints) error { +func (ds *SingleTargetDataSink) sendBatchCPUTimeRecord(ctx context.Context, records []*dataPoints) error { if len(records) == 0 { return nil } start := time.Now() - client := tipb.NewTopSQLAgentClient(r.conn) + client := tipb.NewTopSQLAgentClient(ds.conn) stream, err := client.ReportCPUTimeRecords(ctx) if err != nil { return err @@ -133,9 +213,9 @@ func (r *GRPCReportClient) sendBatchCPUTimeRecord(ctx context.Context, records [ } // sendBatchSQLMeta sends a batch of SQL metas by stream. -func (r *GRPCReportClient) sendBatchSQLMeta(ctx context.Context, sqlMap *sync.Map) error { +func (ds *SingleTargetDataSink) sendBatchSQLMeta(ctx context.Context, sqlMap *sync.Map) error { start := time.Now() - client := tipb.NewTopSQLAgentClient(r.conn) + client := tipb.NewTopSQLAgentClient(ds.conn) stream, err := client.ReportSQLMeta(ctx) if err != nil { return err @@ -169,16 +249,16 @@ func (r *GRPCReportClient) sendBatchSQLMeta(ctx context.Context, sqlMap *sync.Ma } // sendBatchPlanMeta sends a batch of SQL metas by stream. -func (r *GRPCReportClient) sendBatchPlanMeta(ctx context.Context, planMap *sync.Map) error { +func (ds *SingleTargetDataSink) sendBatchPlanMeta(ctx context.Context, planMap *sync.Map) error { start := time.Now() - client := tipb.NewTopSQLAgentClient(r.conn) + client := tipb.NewTopSQLAgentClient(ds.conn) stream, err := client.ReportPlanMeta(ctx) if err != nil { return err } cnt := 0 planMap.Range(func(key, value interface{}) bool { - planDecoded, errDecode := r.decodePlan(value.(string)) + planDecoded, errDecode := ds.decodePlan(value.(string)) if errDecode != nil { logutil.BgLogger().Warn("[top-sql] decode plan failed", zap.Error(errDecode)) return true @@ -208,26 +288,26 @@ func (r *GRPCReportClient) sendBatchPlanMeta(ctx context.Context, planMap *sync. } // tryEstablishConnection establishes the gRPC connection if connection is not established. 
-func (r *GRPCReportClient) tryEstablishConnection(ctx context.Context, targetRPCAddr string) (err error) { - if r.curRPCAddr == targetRPCAddr && r.conn != nil { +func (ds *SingleTargetDataSink) tryEstablishConnection(ctx context.Context, targetRPCAddr string) (err error) { + if ds.curRPCAddr == targetRPCAddr && ds.conn != nil { // Address is not changed, skip. return nil } - if r.conn != nil { - err := r.conn.Close() - logutil.BgLogger().Warn("[top-sql] grpc client close connection failed", zap.Error(err)) + if ds.conn != nil { + err := ds.conn.Close() + logutil.BgLogger().Warn("[top-sql] grpc dataSink close connection failed", zap.Error(err)) } - r.conn, err = r.dial(ctx, targetRPCAddr) + ds.conn, err = ds.dial(ctx, targetRPCAddr) if err != nil { return err } - r.curRPCAddr = targetRPCAddr + ds.curRPCAddr = targetRPCAddr return nil } -func (r *GRPCReportClient) dial(ctx context.Context, targetRPCAddr string) (*grpc.ClientConn, error) { +func (ds *SingleTargetDataSink) dial(ctx context.Context, targetRPCAddr string) (*grpc.ClientConn, error) { dialCtx, cancel := context.WithTimeout(ctx, dialTimeout) defer cancel() return grpc.DialContext( @@ -250,3 +330,8 @@ func (r *GRPCReportClient) dial(ctx context.Context, targetRPCAddr string) (*grp }), ) } + +type sendTask struct { + data ReportData + deadline time.Time +} diff --git a/util/topsql/topsql.go b/util/topsql/topsql.go index 5a458b4a21f3f..67ceb242039b5 100644 --- a/util/topsql/topsql.go +++ b/util/topsql/topsql.go @@ -40,8 +40,8 @@ var globalTopSQLReport reporter.TopSQLReporter // SetupTopSQL sets up the top-sql worker. func SetupTopSQL() { - rc := reporter.NewGRPCReportClient(plancodec.DecodeNormalizedPlan) - globalTopSQLReport = reporter.NewRemoteTopSQLReporter(rc) + ds := reporter.NewSingleTargetDataSink(plancodec.DecodeNormalizedPlan) + globalTopSQLReport = reporter.NewRemoteTopSQLReporter(ds) tracecpu.GlobalSQLCPUProfiler.SetCollector(globalTopSQLReport) tracecpu.GlobalSQLCPUProfiler.Run() } diff --git a/util/topsql/topsql_test.go b/util/topsql/topsql_test.go index 616095231bfbe..d9ec6799194d7 100644 --- a/util/topsql/topsql_test.go +++ b/util/topsql/topsql_test.go @@ -115,7 +115,7 @@ func TestTopSQLReporter(t *testing.T) { conf.TopSQL.ReceiverAddress = server.Address() }) - client := reporter.NewGRPCReportClient(mockPlanBinaryDecoderFunc) + client := reporter.NewSingleTargetDataSink(mockPlanBinaryDecoderFunc) report := reporter.NewRemoteTopSQLReporter(client) defer report.Close() From 2f42f7c0f698d088c095cde7ec1cb52014238979 Mon Sep 17 00:00:00 2001 From: Yuanjia Zhang Date: Tue, 14 Dec 2021 19:36:35 +0800 Subject: [PATCH 3/8] planner: unify the argument of stats functions to use SessionCtx instead of StatementContext (#30668) --- planner/core/find_best_task.go | 17 +++-- planner/core/logical_plans.go | 9 +-- planner/core/rule_partition_processor.go | 6 +- planner/core/stats.go | 2 +- planner/util/path.go | 2 +- statistics/handle/ddl_serial_test.go | 14 ++-- statistics/handle/handle_test.go | 13 ++-- statistics/handle/update.go | 38 +++++++++-- statistics/handle/update_test.go | 2 +- statistics/histogram.go | 53 ++++++++------- statistics/histogram_test.go | 4 +- statistics/selectivity.go | 22 +++---- statistics/selectivity_serial_test.go | 40 +++++------ statistics/statistics_test.go | 73 ++++++++++---------- statistics/table.go | 84 +++++++++++++----------- util/ranger/types.go | 4 +- 16 files changed, 204 insertions(+), 179 deletions(-) diff --git a/planner/core/find_best_task.go b/planner/core/find_best_task.go index 
c23614e7c5935..59a182f9f8e7c 100644 --- a/planner/core/find_best_task.go +++ b/planner/core/find_best_task.go @@ -27,7 +27,7 @@ import ( "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/planner/property" "github.com/pingcap/tidb/planner/util" - "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/types" tidbutil "github.com/pingcap/tidb/util" @@ -1478,7 +1478,7 @@ func getMostCorrCol4Handle(exprs []expression.Expression, histColl *statistics.T } // getColumnRangeCounts estimates row count for each range respectively. -func getColumnRangeCounts(sc *stmtctx.StatementContext, colID int64, ranges []*ranger.Range, histColl *statistics.HistColl, idxID int64) ([]float64, bool) { +func getColumnRangeCounts(sctx sessionctx.Context, colID int64, ranges []*ranger.Range, histColl *statistics.HistColl, idxID int64) ([]float64, bool) { var err error var count float64 rangeCounts := make([]float64, len(ranges)) @@ -1488,13 +1488,13 @@ func getColumnRangeCounts(sc *stmtctx.StatementContext, colID int64, ranges []*r if idxHist == nil || idxHist.IsInvalid(false) { return nil, false } - count, err = histColl.GetRowCountByIndexRanges(sc, idxID, []*ranger.Range{ran}) + count, err = histColl.GetRowCountByIndexRanges(sctx, idxID, []*ranger.Range{ran}) } else { colHist, ok := histColl.Columns[colID] - if !ok || colHist.IsInvalid(sc, false) { + if !ok || colHist.IsInvalid(sctx, false) { return nil, false } - count, err = histColl.GetRowCountByColumnRanges(sc, colID, []*ranger.Range{ran}) + count, err = histColl.GetRowCountByColumnRanges(sctx, colID, []*ranger.Range{ran}) } if err != nil { return nil, false @@ -1564,7 +1564,6 @@ func (ds *DataSource) crossEstimateRowCount(path *util.AccessPath, conds []expre if len(accessConds) == 0 { return 0, false, corr } - sc := ds.ctx.GetSessionVars().StmtCtx ranges, err := ranger.BuildColumnRange(accessConds, ds.ctx, col.RetType, types.UnspecifiedLength) if len(ranges) == 0 || err != nil { return 0, err == nil, corr @@ -1573,7 +1572,7 @@ func (ds *DataSource) crossEstimateRowCount(path *util.AccessPath, conds []expre if !idxExists { idxID = -1 } - rangeCounts, ok := getColumnRangeCounts(sc, colID, ranges, ds.tableStats.HistColl, idxID) + rangeCounts, ok := getColumnRangeCounts(ds.ctx, colID, ranges, ds.tableStats.HistColl, idxID) if !ok { return 0, false, corr } @@ -1583,9 +1582,9 @@ func (ds *DataSource) crossEstimateRowCount(path *util.AccessPath, conds []expre } var rangeCount float64 if idxExists { - rangeCount, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(sc, idxID, convertedRanges) + rangeCount, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(ds.ctx, idxID, convertedRanges) } else { - rangeCount, err = ds.tableStats.HistColl.GetRowCountByColumnRanges(sc, colID, convertedRanges) + rangeCount, err = ds.tableStats.HistColl.GetRowCountByColumnRanges(ds.ctx, colID, convertedRanges) } if err != nil { return 0, false, corr diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index 212f10d65346a..5fe0426b5c15b 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -726,7 +726,6 @@ func (ds *DataSource) deriveCommonHandleTablePathStats(path *util.AccessPath, co if len(conds) == 0 { return nil } - sc := ds.ctx.GetSessionVars().StmtCtx if len(path.IdxCols) != 0 { res, err := ranger.DetachCondAndBuildRangeForIndex(ds.ctx, conds, path.IdxCols, path.IdxColLens) if err != nil { @@ -744,7 +743,7 @@ func (ds 
*DataSource) deriveCommonHandleTablePathStats(path *util.AccessPath, co path.ConstCols[i] = res.ColumnValues[i] != nil } } - path.CountAfterAccess, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(sc, path.Index.ID, path.Ranges) + path.CountAfterAccess, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(ds.ctx, path.Index.ID, path.Ranges) if err != nil { return err } @@ -785,7 +784,6 @@ func (ds *DataSource) deriveTablePathStats(path *util.AccessPath, conds []expres return ds.deriveCommonHandleTablePathStats(path, conds, isIm) } var err error - sc := ds.ctx.GetSessionVars().StmtCtx path.CountAfterAccess = float64(ds.statisticTable.Count) path.TableFilters = conds var pkCol *expression.Column @@ -848,7 +846,7 @@ func (ds *DataSource) deriveTablePathStats(path *util.AccessPath, conds []expres if err != nil { return err } - path.CountAfterAccess, err = ds.statisticTable.GetRowCountByIntColumnRanges(sc, pkCol.ID, path.Ranges) + path.CountAfterAccess, err = ds.statisticTable.GetRowCountByIntColumnRanges(ds.ctx, pkCol.ID, path.Ranges) // If the `CountAfterAccess` is less than `stats.RowCount`, there must be some inconsistent stats info. // We prefer the `stats.RowCount` because it could use more stats info to calculate the selectivity. if path.CountAfterAccess < ds.stats.RowCount && !isIm { @@ -858,7 +856,6 @@ func (ds *DataSource) deriveTablePathStats(path *util.AccessPath, conds []expres } func (ds *DataSource) fillIndexPath(path *util.AccessPath, conds []expression.Expression) error { - sc := ds.ctx.GetSessionVars().StmtCtx path.Ranges = ranger.FullRange() path.CountAfterAccess = float64(ds.statisticTable.Count) path.IdxCols, path.IdxColLens = expression.IndexInfo2PrefixCols(ds.Columns, ds.schema.Columns, path.Index) @@ -900,7 +897,7 @@ func (ds *DataSource) fillIndexPath(path *util.AccessPath, conds []expression.Ex path.ConstCols[i] = res.ColumnValues[i] != nil } } - path.CountAfterAccess, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(sc, path.Index.ID, path.Ranges) + path.CountAfterAccess, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(ds.ctx, path.Index.ID, path.Ranges) if err != nil { return err } diff --git a/planner/core/rule_partition_processor.go b/planner/core/rule_partition_processor.go index 7c3bbb565c69d..bb57b0fac33da 100644 --- a/planner/core/rule_partition_processor.go +++ b/planner/core/rule_partition_processor.go @@ -140,7 +140,7 @@ func (s *partitionProcessor) findUsedPartitions(ctx sessionctx.Context, tbl tabl ranges := detachedResult.Ranges used := make([]int, 0, len(ranges)) for _, r := range ranges { - if r.IsPointNullable(ctx.GetSessionVars().StmtCtx) { + if r.IsPointNullable(ctx) { if !r.HighVal[0].IsNull() { if len(r.HighVal) != len(partIdx) { used = []int{-1} @@ -473,7 +473,7 @@ func (l *listPartitionPruner) locateColumnPartitionsByCondition(cond expression. 
return nil, true, nil } var locations []tables.ListPartitionLocation - if r.IsPointNullable(l.ctx.GetSessionVars().StmtCtx) { + if r.IsPointNullable(l.ctx) { location, err := colPrune.LocatePartition(sc, r.HighVal[0]) if types.ErrOverflow.Equal(err) { return nil, true, nil // return full-scan if over-flow @@ -555,7 +555,7 @@ func (l *listPartitionPruner) findUsedListPartitions(conds []expression.Expressi } used := make(map[int]struct{}, len(ranges)) for _, r := range ranges { - if r.IsPointNullable(l.ctx.GetSessionVars().StmtCtx) { + if r.IsPointNullable(l.ctx) { if len(r.HighVal) != len(exprCols) { return l.fullRange, nil } diff --git a/planner/core/stats.go b/planner/core/stats.go index 14a6a11a2c2d4..2e7fd14a67b8d 100644 --- a/planner/core/stats.go +++ b/planner/core/stats.go @@ -253,7 +253,7 @@ func (ds *DataSource) deriveStatsByFilter(conds expression.CNFExprs, filledPaths } stats := ds.tableStats.Scale(selectivity) if ds.ctx.GetSessionVars().OptimizerSelectivityLevel >= 1 { - stats.HistColl = stats.HistColl.NewHistCollBySelectivity(ds.ctx.GetSessionVars().StmtCtx, nodes) + stats.HistColl = stats.HistColl.NewHistCollBySelectivity(ds.ctx, nodes) } return stats } diff --git a/planner/util/path.go b/planner/util/path.go index 76e6c5173793d..9a0b4207d1314 100644 --- a/planner/util/path.go +++ b/planner/util/path.go @@ -153,7 +153,7 @@ func isColEqCorColOrConstant(ctx sessionctx.Context, filter expression.Expressio func (path *AccessPath) OnlyPointRange(sctx sessionctx.Context) bool { if path.IsIntHandlePath { for _, ran := range path.Ranges { - if !ran.IsPointNullable(sctx.GetSessionVars().StmtCtx) { + if !ran.IsPointNullable(sctx) { return false } } diff --git a/statistics/handle/ddl_serial_test.go b/statistics/handle/ddl_serial_test.go index 76121694338df..91a3a244cb17d 100644 --- a/statistics/handle/ddl_serial_test.go +++ b/statistics/handle/ddl_serial_test.go @@ -18,10 +18,10 @@ import ( "testing" "github.com/pingcap/tidb/parser/model" - "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/mock" "github.com/stretchr/testify/require" ) @@ -51,10 +51,10 @@ func TestDDLAfterLoad(t *testing.T) { require.NoError(t, err) tableInfo = tbl.Meta() - sc := new(stmtctx.StatementContext) - count := statsTbl.ColumnGreaterRowCount(sc, types.NewDatum(recordCount+1), tableInfo.Columns[0].ID) + sctx := mock.NewContext() + count := statsTbl.ColumnGreaterRowCount(sctx, types.NewDatum(recordCount+1), tableInfo.Columns[0].ID) require.Equal(t, 0.0, count) - count = statsTbl.ColumnGreaterRowCount(sc, types.NewDatum(recordCount+1), tableInfo.Columns[2].ID) + count = statsTbl.ColumnGreaterRowCount(sctx, types.NewDatum(recordCount+1), tableInfo.Columns[2].ID) require.Equal(t, 333, int(count)) } @@ -131,11 +131,11 @@ func TestDDLHistogram(t *testing.T) { tableInfo = tbl.Meta() statsTbl = do.StatsHandle().GetTableStats(tableInfo) require.False(t, statsTbl.Pseudo) - sc := new(stmtctx.StatementContext) - count, err := statsTbl.ColumnEqualRowCount(sc, types.NewIntDatum(0), tableInfo.Columns[3].ID) + sctx := mock.NewContext() + count, err := statsTbl.ColumnEqualRowCount(sctx, types.NewIntDatum(0), tableInfo.Columns[3].ID) require.NoError(t, err) require.Equal(t, float64(2), count) - count, err = statsTbl.ColumnEqualRowCount(sc, types.NewIntDatum(1), tableInfo.Columns[3].ID) + count, err = statsTbl.ColumnEqualRowCount(sctx, types.NewIntDatum(1), tableInfo.Columns[3].ID) 
require.NoError(t, err) require.Equal(t, float64(0), count) diff --git a/statistics/handle/handle_test.go b/statistics/handle/handle_test.go index 23b2de4333af8..70ec989f7bca6 100644 --- a/statistics/handle/handle_test.go +++ b/statistics/handle/handle_test.go @@ -32,7 +32,6 @@ import ( "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/terror" "github.com/pingcap/tidb/session" - "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/statistics/handle" @@ -40,6 +39,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/israce" + "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/ranger" "github.com/pingcap/tidb/util/testkit" "github.com/tikv/client-go/v2/oracle" @@ -267,8 +267,7 @@ func (s *testStatsSuite) TestEmptyTable(c *C) { c.Assert(err, IsNil) tableInfo := tbl.Meta() statsTbl := do.StatsHandle().GetTableStats(tableInfo) - sc := new(stmtctx.StatementContext) - count := statsTbl.ColumnGreaterRowCount(sc, types.NewDatum(1), tableInfo.Columns[0].ID) + count := statsTbl.ColumnGreaterRowCount(mock.NewContext(), types.NewDatum(1), tableInfo.Columns[0].ID) c.Assert(count, Equals, 0.0) } @@ -285,14 +284,14 @@ func (s *testStatsSuite) TestColumnIDs(c *C) { c.Assert(err, IsNil) tableInfo := tbl.Meta() statsTbl := do.StatsHandle().GetTableStats(tableInfo) - sc := new(stmtctx.StatementContext) + sctx := mock.NewContext() ran := &ranger.Range{ LowVal: []types.Datum{types.MinNotNullDatum()}, HighVal: []types.Datum{types.NewIntDatum(2)}, LowExclude: false, HighExclude: true, } - count, err := statsTbl.GetRowCountByColumnRanges(sc, tableInfo.Columns[0].ID, []*ranger.Range{ran}) + count, err := statsTbl.GetRowCountByColumnRanges(sctx, tableInfo.Columns[0].ID, []*ranger.Range{ran}) c.Assert(err, IsNil) c.Assert(count, Equals, float64(1)) @@ -307,7 +306,7 @@ func (s *testStatsSuite) TestColumnIDs(c *C) { tableInfo = tbl.Meta() statsTbl = do.StatsHandle().GetTableStats(tableInfo) // At that time, we should get c2's stats instead of c1's. 
- count, err = statsTbl.GetRowCountByColumnRanges(sc, tableInfo.Columns[0].ID, []*ranger.Range{ran}) + count, err = statsTbl.GetRowCountByColumnRanges(sctx, tableInfo.Columns[0].ID, []*ranger.Range{ran}) c.Assert(err, IsNil) c.Assert(count, Equals, 0.0) } @@ -614,7 +613,7 @@ func (s *testStatsSuite) TestLoadStats(c *C) { c.Assert(hg.Len(), Equals, 0) cms = stat.Columns[tableInfo.Columns[2].ID].CMSketch c.Assert(cms, IsNil) - _, err = stat.ColumnEqualRowCount(testKit.Se.GetSessionVars().StmtCtx, types.NewIntDatum(1), tableInfo.Columns[2].ID) + _, err = stat.ColumnEqualRowCount(testKit.Se, types.NewIntDatum(1), tableInfo.Columns[2].ID) c.Assert(err, IsNil) c.Assert(h.LoadNeededHistograms(), IsNil) stat = h.GetTableStats(tableInfo) diff --git a/statistics/handle/update.go b/statistics/handle/update.go index e154755b5cc8d..a36f12bbdd7d2 100644 --- a/statistics/handle/update.go +++ b/statistics/handle/update.go @@ -32,6 +32,7 @@ import ( "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" + "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/statistics" @@ -1266,7 +1267,18 @@ func (h *Handle) RecalculateExpectCount(q *statistics.QueryFeedback) error { return nil } - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + se, err := h.pool.Get() + if err != nil { + return err + } + sctx := se.(sessionctx.Context) + timeZone := sctx.GetSessionVars().StmtCtx.TimeZone + defer func() { + sctx.GetSessionVars().StmtCtx.TimeZone = timeZone + h.pool.Put(se) + }() + sctx.GetSessionVars().StmtCtx.TimeZone = time.UTC + ranges, err := q.DecodeToRanges(isIndex) if err != nil { return errors.Trace(err) @@ -1274,10 +1286,10 @@ func (h *Handle) RecalculateExpectCount(q *statistics.QueryFeedback) error { expected := 0.0 if isIndex { idx := t.Indices[id] - expected, err = idx.GetRowCount(sc, nil, ranges, t.Count) + expected, err = idx.GetRowCount(sctx, nil, ranges, t.Count) } else { c := t.Columns[id] - expected, err = c.GetColumnRowCount(sc, ranges, t.Count, true) + expected, err = c.GetColumnRowCount(sctx, ranges, t.Count, true) } q.Expected = int64(expected) return err @@ -1354,7 +1366,20 @@ func (h *Handle) DumpFeedbackForIndex(q *statistics.QueryFeedback, t *statistics if !ok { return nil } - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + + se, err := h.pool.Get() + if err != nil { + return err + } + sctx := se.(sessionctx.Context) + sc := sctx.GetSessionVars().StmtCtx + timeZone := sc.TimeZone + defer func() { + sctx.GetSessionVars().StmtCtx.TimeZone = timeZone + h.pool.Put(se) + }() + sc.TimeZone = time.UTC + if idx.CMSketch == nil || idx.StatsVer < statistics.Version1 { return h.DumpFeedbackToKV(q) } @@ -1369,7 +1394,6 @@ func (h *Handle) DumpFeedbackForIndex(q *statistics.QueryFeedback, t *statistics if rangePosition == 0 || rangePosition == len(ran.LowVal) { continue } - bytes, err := codec.EncodeKey(sc, nil, ran.LowVal[:rangePosition]...) 
if err != nil { logutil.BgLogger().Debug("encode keys fail", zap.Error(err)) @@ -1385,12 +1409,12 @@ func (h *Handle) DumpFeedbackForIndex(q *statistics.QueryFeedback, t *statistics rangeFB := &statistics.QueryFeedback{PhysicalID: q.PhysicalID} // prefer index stats over column stats if idx := t.IndexStartWithColumn(colName); idx != nil && idx.Histogram.Len() != 0 { - rangeCount, err = t.GetRowCountByIndexRanges(sc, idx.ID, []*ranger.Range{rang}) + rangeCount, err = t.GetRowCountByIndexRanges(sctx, idx.ID, []*ranger.Range{rang}) rangeFB.Tp, rangeFB.Hist = statistics.IndexType, &idx.Histogram } else if col := t.ColumnByName(colName); col != nil && col.Histogram.Len() != 0 { err = convertRangeType(rang, col.Tp, time.UTC) if err == nil { - rangeCount, err = t.GetRowCountByColumnRanges(sc, col.ID, []*ranger.Range{rang}) + rangeCount, err = t.GetRowCountByColumnRanges(sctx, col.ID, []*ranger.Range{rang}) rangeFB.Tp, rangeFB.Hist = statistics.ColType, &col.Histogram } } else { diff --git a/statistics/handle/update_test.go b/statistics/handle/update_test.go index fe4904e254be5..3d41f92701593 100644 --- a/statistics/handle/update_test.go +++ b/statistics/handle/update_test.go @@ -149,7 +149,7 @@ func (s *testStatsSuite) TestSingleSessionInsert(c *C) { c.Assert(stats1.Count, Equals, int64(rowCount1*2)) // Test IncreaseFactor. - count, err := stats1.ColumnEqualRowCount(testKit.Se.GetSessionVars().StmtCtx, types.NewIntDatum(1), tableInfo1.Columns[0].ID) + count, err := stats1.ColumnEqualRowCount(testKit.Se, types.NewIntDatum(1), tableInfo1.Columns[0].ID) c.Assert(err, IsNil) c.Assert(count, Equals, float64(rowCount1*2)) diff --git a/statistics/histogram.go b/statistics/histogram.go index a61f1d1405f59..5e1788da7a1ac 100644 --- a/statistics/histogram.go +++ b/statistics/histogram.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" + "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/tablecodec" @@ -506,7 +507,7 @@ func (hg *Histogram) BetweenRowCount(a, b types.Datum) float64 { } // BetweenRowCount estimates the row count for interval [l, r). -func (c *Column) BetweenRowCount(sc *stmtctx.StatementContext, l, r types.Datum, lowEncoded, highEncoded []byte) float64 { +func (c *Column) BetweenRowCount(sctx sessionctx.Context, l, r types.Datum, lowEncoded, highEncoded []byte) float64 { histBetweenCnt := c.Histogram.BetweenRowCount(l, r) if c.StatsVer <= Version1 { return histBetweenCnt @@ -1067,17 +1068,17 @@ var HistogramNeededColumns = neededColumnMap{cols: map[tableColumnID]struct{}{}} // IsInvalid checks if this column is invalid. If this column has histogram but not loaded yet, then we mark it // as need histogram. 
-func (c *Column) IsInvalid(sc *stmtctx.StatementContext, collPseudo bool) bool { +func (c *Column) IsInvalid(sctx sessionctx.Context, collPseudo bool) bool { if collPseudo && c.NotAccurate() { return true } - if c.Histogram.NDV > 0 && c.notNullCount() == 0 && sc != nil { + if c.Histogram.NDV > 0 && c.notNullCount() == 0 && sctx != nil && sctx.GetSessionVars().StmtCtx != nil { HistogramNeededColumns.insert(tableColumnID{TableID: c.PhysicalID, ColumnID: c.Info.ID}) } return c.TotalRowCount() == 0 || (c.Histogram.NDV > 0 && c.notNullCount() == 0) } -func (c *Column) equalRowCount(sc *stmtctx.StatementContext, val types.Datum, encodedVal []byte, realtimeRowCount int64) (float64, error) { +func (c *Column) equalRowCount(sctx sessionctx.Context, val types.Datum, encodedVal []byte, realtimeRowCount int64) (float64, error) { if val.IsNull() { return float64(c.NullCount), nil } @@ -1090,7 +1091,7 @@ func (c *Column) equalRowCount(sc *stmtctx.StatementContext, val types.Datum, en return outOfRangeEQSelectivity(c.Histogram.NDV, realtimeRowCount, int64(c.TotalRowCount())) * c.TotalRowCount(), nil } if c.CMSketch != nil { - count, err := queryValue(sc, c.CMSketch, c.TopN, val) + count, err := queryValue(sctx.GetSessionVars().StmtCtx, c.CMSketch, c.TopN, val) return float64(count), errors.Trace(err) } histRowCount, _ := c.Histogram.equalRowCount(val, false) @@ -1123,7 +1124,8 @@ func (c *Column) equalRowCount(sc *stmtctx.StatementContext, val types.Datum, en } // GetColumnRowCount estimates the row count by a slice of Range. -func (c *Column) GetColumnRowCount(sc *stmtctx.StatementContext, ranges []*ranger.Range, realtimeRowCount int64, pkIsHandle bool) (float64, error) { +func (c *Column) GetColumnRowCount(sctx sessionctx.Context, ranges []*ranger.Range, realtimeRowCount int64, pkIsHandle bool) (float64, error) { + sc := sctx.GetSessionVars().StmtCtx var rowCount float64 for _, rg := range ranges { highVal := *rg.HighVal[0].Clone() @@ -1155,7 +1157,7 @@ func (c *Column) GetColumnRowCount(sc *stmtctx.StatementContext, ranges []*range continue } var cnt float64 - cnt, err = c.equalRowCount(sc, lowVal, lowEncoded, realtimeRowCount) + cnt, err = c.equalRowCount(sctx, lowVal, lowEncoded, realtimeRowCount) if err != nil { return 0, errors.Trace(err) } @@ -1173,7 +1175,7 @@ func (c *Column) GetColumnRowCount(sc *stmtctx.StatementContext, ranges []*range // case 2: it's a small range && using ver1 stats if rangeVals != nil { for _, val := range rangeVals { - cnt, err := c.equalRowCount(sc, val, lowEncoded, realtimeRowCount) + cnt, err := c.equalRowCount(sctx, val, lowEncoded, realtimeRowCount) if err != nil { return 0, err } @@ -1187,12 +1189,12 @@ func (c *Column) GetColumnRowCount(sc *stmtctx.StatementContext, ranges []*range } // case 3: it's an interval - cnt := c.BetweenRowCount(sc, lowVal, highVal, lowEncoded, highEncoded) + cnt := c.BetweenRowCount(sctx, lowVal, highVal, lowEncoded, highEncoded) // `betweenRowCount` returns count for [l, h) range, we adjust cnt for boundaries here. // Note that, `cnt` does not include null values, we need specially handle cases // where null is the lower bound. 
if rg.LowExclude && !lowVal.IsNull() { - lowCnt, err := c.equalRowCount(sc, lowVal, lowEncoded, realtimeRowCount) + lowCnt, err := c.equalRowCount(sctx, lowVal, lowEncoded, realtimeRowCount) if err != nil { return 0, errors.Trace(err) } @@ -1202,7 +1204,7 @@ func (c *Column) GetColumnRowCount(sc *stmtctx.StatementContext, ranges []*range cnt += float64(c.NullCount) } if !rg.HighExclude { - highCnt, err := c.equalRowCount(sc, highVal, highEncoded, realtimeRowCount) + highCnt, err := c.equalRowCount(sctx, highVal, highEncoded, realtimeRowCount) if err != nil { return 0, errors.Trace(err) } @@ -1326,7 +1328,8 @@ func (idx *Index) QueryBytes(d []byte) uint64 { // GetRowCount returns the row count of the given ranges. // It uses the modifyCount to adjust the influence of modifications on the table. -func (idx *Index) GetRowCount(sc *stmtctx.StatementContext, coll *HistColl, indexRanges []*ranger.Range, realtimeRowCount int64) (float64, error) { +func (idx *Index) GetRowCount(sctx sessionctx.Context, coll *HistColl, indexRanges []*ranger.Range, realtimeRowCount int64) (float64, error) { + sc := sctx.GetSessionVars().StmtCtx totalCount := float64(0) isSingleCol := len(idx.Info.Columns) == 1 for _, indexRange := range indexRanges { @@ -1377,7 +1380,7 @@ func (idx *Index) GetRowCount(sc *stmtctx.StatementContext, coll *HistColl, inde // If the first column's range is point. if rangePosition := GetOrdinalOfRangeCond(sc, indexRange); rangePosition > 0 && idx.StatsVer >= Version2 && coll != nil { var expBackoffSel float64 - expBackoffSel, expBackoffSuccess, err = idx.expBackoffEstimation(sc, coll, indexRange) + expBackoffSel, expBackoffSuccess, err = idx.expBackoffEstimation(sctx, coll, indexRange) if err != nil { return 0, err } @@ -1408,7 +1411,7 @@ func (idx *Index) GetRowCount(sc *stmtctx.StatementContext, coll *HistColl, inde } // expBackoffEstimation estimate the multi-col cases following the Exponential Backoff. See comment below for details. -func (idx *Index) expBackoffEstimation(sc *stmtctx.StatementContext, coll *HistColl, indexRange *ranger.Range) (float64, bool, error) { +func (idx *Index) expBackoffEstimation(sctx sessionctx.Context, coll *HistColl, indexRange *ranger.Range) (float64, bool, error) { tmpRan := []*ranger.Range{ { LowVal: make([]types.Datum, 1), @@ -1435,9 +1438,9 @@ func (idx *Index) expBackoffEstimation(sc *stmtctx.StatementContext, coll *HistC err error ) if anotherIdxID, ok := coll.ColID2IdxID[colID]; ok && anotherIdxID != idx.ID { - count, err = coll.GetRowCountByIndexRanges(sc, anotherIdxID, tmpRan) - } else if col, ok := coll.Columns[colID]; ok && !col.IsInvalid(sc, coll.Pseudo) { - count, err = coll.GetRowCountByColumnRanges(sc, colID, tmpRan) + count, err = coll.GetRowCountByIndexRanges(sctx, anotherIdxID, tmpRan) + } else if col, ok := coll.Columns[colID]; ok && !col.IsInvalid(sctx, coll.Pseudo) { + count, err = coll.GetRowCountByColumnRanges(sctx, colID, tmpRan) } else { continue } @@ -1471,12 +1474,12 @@ func (idx *Index) expBackoffEstimation(sc *stmtctx.StatementContext, coll *HistC return singleColumnEstResults[0] * math.Sqrt(singleColumnEstResults[1]) * math.Sqrt(math.Sqrt(singleColumnEstResults[2])) * math.Sqrt(math.Sqrt(math.Sqrt(singleColumnEstResults[3]))), true, nil } -type countByRangeFunc = func(*stmtctx.StatementContext, int64, []*ranger.Range) (float64, error) +type countByRangeFunc = func(sessionctx.Context, int64, []*ranger.Range) (float64, error) // newHistogramBySelectivity fulfills the content of new histogram by the given selectivity result. 
// TODO: Datum is not efficient, try to avoid using it here. // Also, there're redundant calculation with Selectivity(). We need to reduce it too. -func newHistogramBySelectivity(sc *stmtctx.StatementContext, histID int64, oldHist, newHist *Histogram, ranges []*ranger.Range, cntByRangeFunc countByRangeFunc) error { +func newHistogramBySelectivity(sctx sessionctx.Context, histID int64, oldHist, newHist *Histogram, ranges []*ranger.Range, cntByRangeFunc countByRangeFunc) error { cntPerVal := int64(oldHist.AvgCountPerNotNullValue(int64(oldHist.TotalRowCount()))) var totCnt int64 for boundIdx, ranIdx, highRangeIdx := 0, 0, 0; boundIdx < oldHist.Bounds.NumRows() && ranIdx < len(ranges); boundIdx, ranIdx = boundIdx+2, highRangeIdx { @@ -1489,7 +1492,7 @@ func newHistogramBySelectivity(sc *stmtctx.StatementContext, histID int64, oldHi if ranIdx == highRangeIdx { continue } - cnt, err := cntByRangeFunc(sc, histID, ranges[ranIdx:highRangeIdx]) + cnt, err := cntByRangeFunc(sctx, histID, ranges[ranIdx:highRangeIdx]) // This should not happen. if err != nil { return err @@ -1565,7 +1568,7 @@ func (idx *Index) newIndexBySelectivity(sc *stmtctx.StatementContext, statsNode } // NewHistCollBySelectivity creates new HistColl by the given statsNodes. -func (coll *HistColl) NewHistCollBySelectivity(sc *stmtctx.StatementContext, statsNodes []*StatsNode) *HistColl { +func (coll *HistColl) NewHistCollBySelectivity(sctx sessionctx.Context, statsNodes []*StatsNode) *HistColl { newColl := &HistColl{ Columns: make(map[int64]*Column), Indices: make(map[int64]*Index), @@ -1579,7 +1582,7 @@ func (coll *HistColl) NewHistCollBySelectivity(sc *stmtctx.StatementContext, sta if !ok { continue } - newIdxHist, err := idxHist.newIndexBySelectivity(sc, node) + newIdxHist, err := idxHist.newIndexBySelectivity(sctx.GetSessionVars().StmtCtx, node) if err != nil { logutil.BgLogger().Warn("[Histogram-in-plan]: something wrong happened when calculating row count, "+ "failed to build histogram for index %v of table %v", @@ -1601,7 +1604,7 @@ func (coll *HistColl) NewHistCollBySelectivity(sc *stmtctx.StatementContext, sta } newCol.Histogram = *NewHistogram(oldCol.ID, int64(float64(oldCol.Histogram.NDV)*node.Selectivity), 0, 0, oldCol.Tp, chunk.InitialCapacity, 0) var err error - splitRanges, ok := oldCol.Histogram.SplitRange(sc, node.Ranges, false) + splitRanges, ok := oldCol.Histogram.SplitRange(sctx.GetSessionVars().StmtCtx, node.Ranges, false) if !ok { logutil.BgLogger().Warn("[Histogram-in-plan]: the type of histogram and ranges mismatch") continue @@ -1619,9 +1622,9 @@ func (coll *HistColl) NewHistCollBySelectivity(sc *stmtctx.StatementContext, sta } } if oldCol.IsHandle { - err = newHistogramBySelectivity(sc, node.ID, &oldCol.Histogram, &newCol.Histogram, splitRanges, coll.GetRowCountByIntColumnRanges) + err = newHistogramBySelectivity(sctx, node.ID, &oldCol.Histogram, &newCol.Histogram, splitRanges, coll.GetRowCountByIntColumnRanges) } else { - err = newHistogramBySelectivity(sc, node.ID, &oldCol.Histogram, &newCol.Histogram, splitRanges, coll.GetRowCountByColumnRanges) + err = newHistogramBySelectivity(sctx, node.ID, &oldCol.Histogram, &newCol.Histogram, splitRanges, coll.GetRowCountByColumnRanges) } if err != nil { logutil.BgLogger().Warn("[Histogram-in-plan]: something wrong happened when calculating row count", diff --git a/statistics/histogram_test.go b/statistics/histogram_test.go index 7f5f50e00b91f..15e4d696de1ee 100644 --- a/statistics/histogram_test.go +++ b/statistics/histogram_test.go @@ -92,7 +92,7 @@ num: 54 
lower_bound: kkkkk upper_bound: ooooo repeats: 0 ndv: 0 num: 60 lower_bound: oooooo upper_bound: sssss repeats: 0 ndv: 0 num: 60 lower_bound: ssssssu upper_bound: yyyyy repeats: 0 ndv: 0` - newColl := coll.NewHistCollBySelectivity(sc, []*StatsNode{node, node2}) + newColl := coll.NewHistCollBySelectivity(ctx, []*StatsNode{node, node2}) require.Equal(t, intColResult, newColl.Columns[1].String()) require.Equal(t, stringColResult, newColl.Columns[2].String()) @@ -119,7 +119,7 @@ num: 30 lower_bound: 3 upper_bound: 5 repeats: 10 ndv: 0 num: 30 lower_bound: 9 upper_bound: 11 repeats: 10 ndv: 0 num: 30 lower_bound: 12 upper_bound: 14 repeats: 10 ndv: 0` - newColl = coll.NewHistCollBySelectivity(sc, []*StatsNode{node3}) + newColl = coll.NewHistCollBySelectivity(ctx, []*StatsNode{node3}) require.Equal(t, idxResult, newColl.Indices[0].String()) } diff --git a/statistics/selectivity.go b/statistics/selectivity.go index 86321d561e954..45db1cebf9b1c 100644 --- a/statistics/selectivity.go +++ b/statistics/selectivity.go @@ -27,7 +27,6 @@ import ( "github.com/pingcap/tidb/parser/mysql" planutil "github.com/pingcap/tidb/planner/util" "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" driver "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util/chunk" @@ -193,7 +192,7 @@ func (coll *HistColl) Selectivity(ctx sessionctx.Context, exprs []expression.Exp if len(exprs) > 63 || (len(coll.Columns) == 0 && len(coll.Indices) == 0) { ret = pseudoSelectivity(coll, exprs) if sc.EnableOptimizerCETrace { - CETraceExpr(sc, tableID, "Table Stats-Pseudo-Expression", expression.ComposeCNFCondition(ctx, exprs...), ret*float64(coll.Count)) + CETraceExpr(ctx, tableID, "Table Stats-Pseudo-Expression", expression.ComposeCNFCondition(ctx, exprs...), ret*float64(coll.Count)) } return ret, nil, nil } @@ -210,7 +209,7 @@ func (coll *HistColl) Selectivity(ctx sessionctx.Context, exprs []expression.Exp continue } - if colHist := coll.Columns[c.UniqueID]; colHist == nil || colHist.IsInvalid(sc, coll.Pseudo) { + if colHist := coll.Columns[c.UniqueID]; colHist == nil || colHist.IsInvalid(ctx, coll.Pseudo) { ret *= 1.0 / pseudoEqualRate continue } @@ -236,14 +235,14 @@ func (coll *HistColl) Selectivity(ctx sessionctx.Context, exprs []expression.Exp if colInfo.IsHandle { nodes[len(nodes)-1].Tp = PkType var cnt float64 - cnt, err = coll.GetRowCountByIntColumnRanges(sc, id, ranges) + cnt, err = coll.GetRowCountByIntColumnRanges(ctx, id, ranges) if err != nil { return 0, nil, errors.Trace(err) } nodes[len(nodes)-1].Selectivity = cnt / float64(coll.Count) continue } - cnt, err := coll.GetRowCountByColumnRanges(sc, id, ranges) + cnt, err := coll.GetRowCountByColumnRanges(ctx, id, ranges) if err != nil { return 0, nil, errors.Trace(err) } @@ -274,7 +273,7 @@ func (coll *HistColl) Selectivity(ctx sessionctx.Context, exprs []expression.Exp if err != nil { return 0, nil, errors.Trace(err) } - cnt, err := coll.GetRowCountByIndexRanges(sc, id, ranges) + cnt, err := coll.GetRowCountByIndexRanges(ctx, id, ranges) if err != nil { return 0, nil, errors.Trace(err) } @@ -314,7 +313,7 @@ func (coll *HistColl) Selectivity(ctx sessionctx.Context, exprs []expression.Exp } } expr := expression.ComposeCNFCondition(ctx, curExpr...) 
- CETraceExpr(sc, tableID, "Table Stats-Expression-CNF", expr, ret*float64(coll.Count)) + CETraceExpr(ctx, tableID, "Table Stats-Expression-CNF", expr, ret*float64(coll.Count)) } } @@ -372,7 +371,7 @@ func (coll *HistColl) Selectivity(ctx sessionctx.Context, exprs []expression.Exp selectivity = selectivity + curSelectivity - selectivity*curSelectivity if sc.EnableOptimizerCETrace { // Tracing for the expression estimation results of this DNF. - CETraceExpr(sc, tableID, "Table Stats-Expression-DNF", scalarCond, selectivity*float64(coll.Count)) + CETraceExpr(ctx, tableID, "Table Stats-Expression-DNF", scalarCond, selectivity*float64(coll.Count)) } } @@ -384,7 +383,7 @@ func (coll *HistColl) Selectivity(ctx sessionctx.Context, exprs []expression.Exp // Tracing for the expression estimation results after applying the DNF estimation result. curExpr = append(curExpr, remainedExprs[i]) expr := expression.ComposeCNFCondition(ctx, curExpr...) - CETraceExpr(sc, tableID, "Table Stats-Expression-CNF", expr, ret*float64(coll.Count)) + CETraceExpr(ctx, tableID, "Table Stats-Expression-CNF", expr, ret*float64(coll.Count)) } } } @@ -396,7 +395,7 @@ func (coll *HistColl) Selectivity(ctx sessionctx.Context, exprs []expression.Exp if sc.EnableOptimizerCETrace { // Tracing for the expression estimation results after applying the default selectivity. totalExpr := expression.ComposeCNFCondition(ctx, remainedExprs...) - CETraceExpr(sc, tableID, "Table Stats-Expression-CNF", totalExpr, ret*float64(coll.Count)) + CETraceExpr(ctx, tableID, "Table Stats-Expression-CNF", totalExpr, ret*float64(coll.Count)) } return ret, nodes, nil } @@ -520,7 +519,7 @@ func FindPrefixOfIndexByCol(cols []*expression.Column, idxColIDs []int64, cached } // CETraceExpr appends an expression and related information into CE trace -func CETraceExpr(sc *stmtctx.StatementContext, tableID int64, tp string, expr expression.Expression, rowCount float64) { +func CETraceExpr(sctx sessionctx.Context, tableID int64, tp string, expr expression.Expression, rowCount float64) { exprStr, err := ExprToString(expr) if err != nil { logutil.BgLogger().Debug("[OptimizerTrace] Failed to trace CE of an expression", @@ -533,6 +532,7 @@ func CETraceExpr(sc *stmtctx.StatementContext, tableID int64, tp string, expr ex Expr: exprStr, RowCount: uint64(rowCount), } + sc := sctx.GetSessionVars().StmtCtx sc.OptimizerCETrace = append(sc.OptimizerCETrace, &rec) } diff --git a/statistics/selectivity_serial_test.go b/statistics/selectivity_serial_test.go index 7fdbf09c757dc..a128be3850049 100644 --- a/statistics/selectivity_serial_test.go +++ b/statistics/selectivity_serial_test.go @@ -28,13 +28,13 @@ import ( plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/statistics/handle" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/testkit/testdata" "github.com/pingcap/tidb/util/collate" + "github.com/pingcap/tidb/util/mock" "github.com/stretchr/testify/require" ) @@ -125,9 +125,9 @@ func TestOutOfRangeEstimation(t *testing.T) { table, err := dom.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t")) require.NoError(t, err) statsTbl := h.GetTableStats(table.Meta()) - sc := &stmtctx.StatementContext{} + sctx := mock.NewContext() col := statsTbl.Columns[table.Meta().Columns[0].ID] - count, err := col.GetColumnRowCount(sc, 
getRange(900, 900), statsTbl.Count, false) + count, err := col.GetColumnRowCount(sctx, getRange(900, 900), statsTbl.Count, false) require.NoError(t, err) // Because the ANALYZE collect data by random sampling, so the result is not an accurate value. // so we use a range here. @@ -147,7 +147,7 @@ func TestOutOfRangeEstimation(t *testing.T) { statsSuiteData.GetTestCases(t, &input, &output) increasedTblRowCount := int64(float64(statsTbl.Count) * 1.5) for i, ran := range input { - count, err = col.GetColumnRowCount(sc, getRange(ran.Start, ran.End), increasedTblRowCount, false) + count, err = col.GetColumnRowCount(sctx, getRange(ran.Start, ran.End), increasedTblRowCount, false) require.NoError(t, err) testdata.OnRecord(func() { output[i].Start = ran.Start @@ -184,26 +184,26 @@ func TestEstimationForUnknownValues(t *testing.T) { require.NoError(t, err) statsTbl := h.GetTableStats(table.Meta()) - sc := &stmtctx.StatementContext{} + sctx := mock.NewContext() colID := table.Meta().Columns[0].ID - count, err := statsTbl.GetRowCountByColumnRanges(sc, colID, getRange(30, 30)) + count, err := statsTbl.GetRowCountByColumnRanges(sctx, colID, getRange(30, 30)) require.NoError(t, err) require.Equal(t, 0.2, count) - count, err = statsTbl.GetRowCountByColumnRanges(sc, colID, getRange(9, 30)) + count, err = statsTbl.GetRowCountByColumnRanges(sctx, colID, getRange(9, 30)) require.NoError(t, err) require.Equal(t, 7.2, count) - count, err = statsTbl.GetRowCountByColumnRanges(sc, colID, getRange(9, math.MaxInt64)) + count, err = statsTbl.GetRowCountByColumnRanges(sctx, colID, getRange(9, math.MaxInt64)) require.NoError(t, err) require.Equal(t, 7.2, count) idxID := table.Meta().Indices[0].ID - count, err = statsTbl.GetRowCountByIndexRanges(sc, idxID, getRange(30, 30)) + count, err = statsTbl.GetRowCountByIndexRanges(sctx, idxID, getRange(30, 30)) require.NoError(t, err) require.Equal(t, 0.1, count) - count, err = statsTbl.GetRowCountByIndexRanges(sc, idxID, getRange(9, 30)) + count, err = statsTbl.GetRowCountByIndexRanges(sctx, idxID, getRange(9, 30)) require.NoError(t, err) require.Equal(t, 7.0, count) @@ -215,7 +215,7 @@ func TestEstimationForUnknownValues(t *testing.T) { statsTbl = h.GetTableStats(table.Meta()) colID = table.Meta().Columns[0].ID - count, err = statsTbl.GetRowCountByColumnRanges(sc, colID, getRange(1, 30)) + count, err = statsTbl.GetRowCountByColumnRanges(sctx, colID, getRange(1, 30)) require.NoError(t, err) require.Equal(t, 0.0, count) @@ -228,12 +228,12 @@ func TestEstimationForUnknownValues(t *testing.T) { statsTbl = h.GetTableStats(table.Meta()) colID = table.Meta().Columns[0].ID - count, err = statsTbl.GetRowCountByColumnRanges(sc, colID, getRange(2, 2)) + count, err = statsTbl.GetRowCountByColumnRanges(sctx, colID, getRange(2, 2)) require.NoError(t, err) require.Equal(t, 0.0, count) idxID = table.Meta().Indices[0].ID - count, err = statsTbl.GetRowCountByIndexRanges(sc, idxID, getRange(2, 2)) + count, err = statsTbl.GetRowCountByIndexRanges(sctx, idxID, getRange(2, 2)) require.NoError(t, err) require.Equal(t, 0.0, count) } @@ -252,22 +252,22 @@ func TestEstimationUniqueKeyEqualConds(t *testing.T) { require.NoError(t, err) statsTbl := dom.StatsHandle().GetTableStats(table.Meta()) - sc := &stmtctx.StatementContext{} + sctx := mock.NewContext() idxID := table.Meta().Indices[0].ID - count, err := statsTbl.GetRowCountByIndexRanges(sc, idxID, getRange(7, 7)) + count, err := statsTbl.GetRowCountByIndexRanges(sctx, idxID, getRange(7, 7)) require.NoError(t, err) require.Equal(t, 1.0, count) - count, 
err = statsTbl.GetRowCountByIndexRanges(sc, idxID, getRange(6, 6)) + count, err = statsTbl.GetRowCountByIndexRanges(sctx, idxID, getRange(6, 6)) require.NoError(t, err) require.Equal(t, 1.0, count) colID := table.Meta().Columns[0].ID - count, err = statsTbl.GetRowCountByIntColumnRanges(sc, colID, getRange(7, 7)) + count, err = statsTbl.GetRowCountByIntColumnRanges(sctx, colID, getRange(7, 7)) require.NoError(t, err) require.Equal(t, 1.0, count) - count, err = statsTbl.GetRowCountByIntColumnRanges(sc, colID, getRange(6, 6)) + count, err = statsTbl.GetRowCountByIntColumnRanges(sctx, colID, getRange(6, 6)) require.NoError(t, err) require.Equal(t, 1.0, count) } @@ -760,7 +760,7 @@ func TestSmallRangeEstimation(t *testing.T) { table, err := dom.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t")) require.NoError(t, err) statsTbl := h.GetTableStats(table.Meta()) - sc := &stmtctx.StatementContext{} + sctx := mock.NewContext() col := statsTbl.Columns[table.Meta().Columns[0].ID] var input []struct { @@ -775,7 +775,7 @@ func TestSmallRangeEstimation(t *testing.T) { statsSuiteData := statistics.GetStatsSuiteData() statsSuiteData.GetTestCases(t, &input, &output) for i, ran := range input { - count, err := col.GetColumnRowCount(sc, getRange(ran.Start, ran.End), statsTbl.Count, false) + count, err := col.GetColumnRowCount(sctx, getRange(ran.Start, ran.End), statsTbl.Count, false) require.NoError(t, err) testdata.OnRecord(func() { output[i].Start = ran.Start diff --git a/statistics/statistics_test.go b/statistics/statistics_test.go index 7ed3869d07ff6..c99802bd0314e 100644 --- a/statistics/statistics_test.go +++ b/statistics/statistics_test.go @@ -24,7 +24,6 @@ import ( "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/collate" @@ -199,13 +198,13 @@ func TestPseudoTable(t *testing.T) { tbl := PseudoTable(ti) require.Equal(t, len(tbl.Columns), 1) require.Greater(t, tbl.Count, int64(0)) - sc := new(stmtctx.StatementContext) - count := tbl.ColumnLessRowCount(sc, types.NewIntDatum(100), colInfo.ID) + sctx := mock.NewContext() + count := tbl.ColumnLessRowCount(sctx, types.NewIntDatum(100), colInfo.ID) require.Equal(t, 3333, int(count)) - count, err := tbl.ColumnEqualRowCount(sc, types.NewIntDatum(1000), colInfo.ID) + count, err := tbl.ColumnEqualRowCount(sctx, types.NewIntDatum(1000), colInfo.ID) require.NoError(t, err) require.Equal(t, 10, int(count)) - count, _ = tbl.ColumnBetweenRowCount(sc, types.NewIntDatum(1000), types.NewIntDatum(5000), colInfo.ID) + count, _ = tbl.ColumnBetweenRowCount(sctx, types.NewIntDatum(1000), types.NewIntDatum(5000), colInfo.ID) require.Equal(t, 250, int(count)) ti.Columns = append(ti.Columns, &model.ColumnInfo{ ID: 2, @@ -258,50 +257,50 @@ func SubTestColumnRange() func(*testing.T) { LowVal: []types.Datum{{}}, HighVal: []types.Datum{types.MaxValueDatum()}, }} - count, err := tbl.GetRowCountByColumnRanges(sc, 0, ran) + count, err := tbl.GetRowCountByColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 100000, int(count)) ran[0].LowVal[0] = types.MinNotNullDatum() - count, err = tbl.GetRowCountByColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 99900, int(count)) ran[0].LowVal[0] = types.NewIntDatum(1000) ran[0].LowExclude = true ran[0].HighVal[0] = 
types.NewIntDatum(2000) ran[0].HighExclude = true - count, err = tbl.GetRowCountByColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 2500, int(count)) ran[0].LowExclude = false ran[0].HighExclude = false - count, err = tbl.GetRowCountByColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 2500, int(count)) ran[0].LowVal[0] = ran[0].HighVal[0] - count, err = tbl.GetRowCountByColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 100, int(count)) tbl.Columns[0] = col ran[0].LowVal[0] = types.Datum{} ran[0].HighVal[0] = types.MaxValueDatum() - count, err = tbl.GetRowCountByColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 100000, int(count)) ran[0].LowVal[0] = types.NewIntDatum(1000) ran[0].LowExclude = true ran[0].HighVal[0] = types.NewIntDatum(2000) ran[0].HighExclude = true - count, err = tbl.GetRowCountByColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 9998, int(count)) ran[0].LowExclude = false ran[0].HighExclude = false - count, err = tbl.GetRowCountByColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 10000, int(count)) ran[0].LowVal[0] = ran[0].HighVal[0] - count, err = tbl.GetRowCountByColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 1, int(count)) } @@ -312,7 +311,6 @@ func SubTestIntColumnRanges() func(*testing.T) { s := createTestStatisticsSamples(t) bucketCount := int64(256) ctx := mock.NewContext() - sc := ctx.GetSessionVars().StmtCtx s.pk.(*recordSet).cursor = 0 rowCount, hg, err := buildPK(ctx, bucketCount, 0, s.pk) @@ -330,22 +328,22 @@ func SubTestIntColumnRanges() func(*testing.T) { LowVal: []types.Datum{types.NewIntDatum(math.MinInt64)}, HighVal: []types.Datum{types.NewIntDatum(math.MaxInt64)}, }} - count, err := tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err := tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 100000, int(count)) ran[0].LowVal[0].SetInt64(1000) ran[0].HighVal[0].SetInt64(2000) - count, err = tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 1000, int(count)) ran[0].LowVal[0].SetInt64(1001) ran[0].HighVal[0].SetInt64(1999) - count, err = tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 998, int(count)) ran[0].LowVal[0].SetInt64(1000) ran[0].HighVal[0].SetInt64(1000) - count, err = tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 1, int(count)) @@ -353,49 +351,49 @@ func SubTestIntColumnRanges() func(*testing.T) { LowVal: []types.Datum{types.NewUintDatum(0)}, HighVal: []types.Datum{types.NewUintDatum(math.MaxUint64)}, }} - count, err = tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 100000, int(count)) ran[0].LowVal[0].SetUint64(1000) ran[0].HighVal[0].SetUint64(2000) - count, err = tbl.GetRowCountByIntColumnRanges(sc, 0, ran) 
+ count, err = tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 1000, int(count)) ran[0].LowVal[0].SetUint64(1001) ran[0].HighVal[0].SetUint64(1999) - count, err = tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 998, int(count)) ran[0].LowVal[0].SetUint64(1000) ran[0].HighVal[0].SetUint64(1000) - count, err = tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 1, int(count)) tbl.Columns[0] = col ran[0].LowVal[0].SetInt64(math.MinInt64) ran[0].HighVal[0].SetInt64(math.MaxInt64) - count, err = tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 100000, int(count)) ran[0].LowVal[0].SetInt64(1000) ran[0].HighVal[0].SetInt64(2000) - count, err = tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 1001, int(count)) ran[0].LowVal[0].SetInt64(1001) ran[0].HighVal[0].SetInt64(1999) - count, err = tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 999, int(count)) ran[0].LowVal[0].SetInt64(1000) ran[0].HighVal[0].SetInt64(1000) - count, err = tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 1, int(count)) tbl.Count *= 10 - count, err = tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 1, int(count)) } @@ -406,7 +404,6 @@ func SubTestIndexRanges() func(*testing.T) { s := createTestStatisticsSamples(t) bucketCount := int64(256) ctx := mock.NewContext() - sc := ctx.GetSessionVars().StmtCtx s.rc.(*recordSet).cursor = 0 rowCount, hg, cms, err := buildIndex(ctx, bucketCount, 0, s.rc) @@ -425,51 +422,51 @@ func SubTestIndexRanges() func(*testing.T) { LowVal: []types.Datum{types.MinNotNullDatum()}, HighVal: []types.Datum{types.MaxValueDatum()}, }} - count, err := tbl.GetRowCountByIndexRanges(sc, 0, ran) + count, err := tbl.GetRowCountByIndexRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 99900, int(count)) ran[0].LowVal[0] = types.NewIntDatum(1000) ran[0].HighVal[0] = types.NewIntDatum(2000) - count, err = tbl.GetRowCountByIndexRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIndexRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 2500, int(count)) ran[0].LowVal[0] = types.NewIntDatum(1001) ran[0].HighVal[0] = types.NewIntDatum(1999) - count, err = tbl.GetRowCountByIndexRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIndexRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 2500, int(count)) ran[0].LowVal[0] = types.NewIntDatum(1000) ran[0].HighVal[0] = types.NewIntDatum(1000) - count, err = tbl.GetRowCountByIndexRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIndexRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 100, int(count)) tbl.Indices[0] = &Index{Info: &model.IndexInfo{Columns: []*model.IndexColumn{{Offset: 0}}, Unique: true}} ran[0].LowVal[0] = types.NewIntDatum(1000) ran[0].HighVal[0] = types.NewIntDatum(1000) - count, err = tbl.GetRowCountByIndexRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIndexRanges(ctx, 0, ran) require.NoError(t, err) 
require.Equal(t, 1, int(count)) tbl.Indices[0] = idx ran[0].LowVal[0] = types.MinNotNullDatum() ran[0].HighVal[0] = types.MaxValueDatum() - count, err = tbl.GetRowCountByIndexRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIndexRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 100000, int(count)) ran[0].LowVal[0] = types.NewIntDatum(1000) ran[0].HighVal[0] = types.NewIntDatum(2000) - count, err = tbl.GetRowCountByIndexRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIndexRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 1000, int(count)) ran[0].LowVal[0] = types.NewIntDatum(1001) ran[0].HighVal[0] = types.NewIntDatum(1990) - count, err = tbl.GetRowCountByIndexRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIndexRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 989, int(count)) ran[0].LowVal[0] = types.NewIntDatum(1000) ran[0].HighVal[0] = types.NewIntDatum(1000) - count, err = tbl.GetRowCountByIndexRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIndexRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 0, int(count)) } diff --git a/statistics/table.go b/statistics/table.go index 358744716525b..10e08001c7528 100644 --- a/statistics/table.go +++ b/statistics/table.go @@ -285,27 +285,28 @@ func (t *Table) IsOutdated() bool { } // ColumnGreaterRowCount estimates the row count where the column greater than value. -func (t *Table) ColumnGreaterRowCount(sc *stmtctx.StatementContext, value types.Datum, colID int64) float64 { +func (t *Table) ColumnGreaterRowCount(sctx sessionctx.Context, value types.Datum, colID int64) float64 { c, ok := t.Columns[colID] - if !ok || c.IsInvalid(sc, t.Pseudo) { + if !ok || c.IsInvalid(sctx, t.Pseudo) { return float64(t.Count) / pseudoLessRate } return c.greaterRowCount(value) * c.GetIncreaseFactor(t.Count) } // ColumnLessRowCount estimates the row count where the column less than value. Note that null values are not counted. -func (t *Table) ColumnLessRowCount(sc *stmtctx.StatementContext, value types.Datum, colID int64) float64 { +func (t *Table) ColumnLessRowCount(sctx sessionctx.Context, value types.Datum, colID int64) float64 { c, ok := t.Columns[colID] - if !ok || c.IsInvalid(sc, t.Pseudo) { + if !ok || c.IsInvalid(sctx, t.Pseudo) { return float64(t.Count) / pseudoLessRate } return c.lessRowCount(value) * c.GetIncreaseFactor(t.Count) } // ColumnBetweenRowCount estimates the row count where column greater or equal to a and less than b. -func (t *Table) ColumnBetweenRowCount(sc *stmtctx.StatementContext, a, b types.Datum, colID int64) (float64, error) { +func (t *Table) ColumnBetweenRowCount(sctx sessionctx.Context, a, b types.Datum, colID int64) (float64, error) { + sc := sctx.GetSessionVars().StmtCtx c, ok := t.Columns[colID] - if !ok || c.IsInvalid(sc, t.Pseudo) { + if !ok || c.IsInvalid(sctx, t.Pseudo) { return float64(t.Count) / pseudoBetweenRate, nil } aEncoded, err := codec.EncodeKey(sc, nil, a) @@ -316,7 +317,7 @@ func (t *Table) ColumnBetweenRowCount(sc *stmtctx.StatementContext, a, b types.D if err != nil { return 0, err } - count := c.BetweenRowCount(sc, a, b, aEncoded, bEncoded) + count := c.BetweenRowCount(sctx, a, b, aEncoded, bEncoded) if a.IsNull() { count += float64(c.NullCount) } @@ -324,25 +325,26 @@ func (t *Table) ColumnBetweenRowCount(sc *stmtctx.StatementContext, a, b types.D } // ColumnEqualRowCount estimates the row count where the column equals to value. 
-func (t *Table) ColumnEqualRowCount(sc *stmtctx.StatementContext, value types.Datum, colID int64) (float64, error) { +func (t *Table) ColumnEqualRowCount(sctx sessionctx.Context, value types.Datum, colID int64) (float64, error) { c, ok := t.Columns[colID] - if !ok || c.IsInvalid(sc, t.Pseudo) { + if !ok || c.IsInvalid(sctx, t.Pseudo) { return float64(t.Count) / pseudoEqualRate, nil } - encodedVal, err := codec.EncodeKey(sc, nil, value) + encodedVal, err := codec.EncodeKey(sctx.GetSessionVars().StmtCtx, nil, value) if err != nil { return 0, err } - result, err := c.equalRowCount(sc, value, encodedVal, t.ModifyCount) + result, err := c.equalRowCount(sctx, value, encodedVal, t.ModifyCount) result *= c.GetIncreaseFactor(t.Count) return result, errors.Trace(err) } // GetRowCountByIntColumnRanges estimates the row count by a slice of IntColumnRange. -func (coll *HistColl) GetRowCountByIntColumnRanges(sc *stmtctx.StatementContext, colID int64, intRanges []*ranger.Range) (float64, error) { +func (coll *HistColl) GetRowCountByIntColumnRanges(sctx sessionctx.Context, colID int64, intRanges []*ranger.Range) (float64, error) { + sc := sctx.GetSessionVars().StmtCtx var result float64 c, ok := coll.Columns[colID] - if !ok || c.IsInvalid(sc, coll.Pseudo) { + if !ok || c.IsInvalid(sctx, coll.Pseudo) { if len(intRanges) == 0 { return 0, nil } @@ -352,36 +354,38 @@ func (coll *HistColl) GetRowCountByIntColumnRanges(sc *stmtctx.StatementContext, result = getPseudoRowCountByUnsignedIntRanges(intRanges, float64(coll.Count)) } if sc.EnableOptimizerCETrace && ok { - CETraceRange(sc, coll.PhysicalID, []string{c.Info.Name.O}, intRanges, "Column Stats-Pseudo", uint64(result)) + CETraceRange(sctx, coll.PhysicalID, []string{c.Info.Name.O}, intRanges, "Column Stats-Pseudo", uint64(result)) } return result, nil } - result, err := c.GetColumnRowCount(sc, intRanges, coll.Count, true) + result, err := c.GetColumnRowCount(sctx, intRanges, coll.Count, true) if sc.EnableOptimizerCETrace { - CETraceRange(sc, coll.PhysicalID, []string{c.Info.Name.O}, intRanges, "Column Stats", uint64(result)) + CETraceRange(sctx, coll.PhysicalID, []string{c.Info.Name.O}, intRanges, "Column Stats", uint64(result)) } return result, errors.Trace(err) } // GetRowCountByColumnRanges estimates the row count by a slice of Range. 
-func (coll *HistColl) GetRowCountByColumnRanges(sc *stmtctx.StatementContext, colID int64, colRanges []*ranger.Range) (float64, error) { +func (coll *HistColl) GetRowCountByColumnRanges(sctx sessionctx.Context, colID int64, colRanges []*ranger.Range) (float64, error) { + sc := sctx.GetSessionVars().StmtCtx c, ok := coll.Columns[colID] - if !ok || c.IsInvalid(sc, coll.Pseudo) { + if !ok || c.IsInvalid(sctx, coll.Pseudo) { result, err := GetPseudoRowCountByColumnRanges(sc, float64(coll.Count), colRanges, 0) if err == nil && sc.EnableOptimizerCETrace && ok { - CETraceRange(sc, coll.PhysicalID, []string{c.Info.Name.O}, colRanges, "Column Stats-Pseudo", uint64(result)) + CETraceRange(sctx, coll.PhysicalID, []string{c.Info.Name.O}, colRanges, "Column Stats-Pseudo", uint64(result)) } return result, err } - result, err := c.GetColumnRowCount(sc, colRanges, coll.Count, false) + result, err := c.GetColumnRowCount(sctx, colRanges, coll.Count, false) if sc.EnableOptimizerCETrace { - CETraceRange(sc, coll.PhysicalID, []string{c.Info.Name.O}, colRanges, "Column Stats", uint64(result)) + CETraceRange(sctx, coll.PhysicalID, []string{c.Info.Name.O}, colRanges, "Column Stats", uint64(result)) } return result, errors.Trace(err) } // GetRowCountByIndexRanges estimates the row count by a slice of Range. -func (coll *HistColl) GetRowCountByIndexRanges(sc *stmtctx.StatementContext, idxID int64, indexRanges []*ranger.Range) (float64, error) { +func (coll *HistColl) GetRowCountByIndexRanges(sctx sessionctx.Context, idxID int64, indexRanges []*ranger.Range) (float64, error) { + sc := sctx.GetSessionVars().StmtCtx idx, ok := coll.Indices[idxID] colNames := make([]string, 0, 8) if ok { @@ -396,28 +400,29 @@ func (coll *HistColl) GetRowCountByIndexRanges(sc *stmtctx.StatementContext, idx } result, err := getPseudoRowCountByIndexRanges(sc, indexRanges, float64(coll.Count), colsLen) if err == nil && sc.EnableOptimizerCETrace && ok { - CETraceRange(sc, coll.PhysicalID, colNames, indexRanges, "Index Stats-Pseudo", uint64(result)) + CETraceRange(sctx, coll.PhysicalID, colNames, indexRanges, "Index Stats-Pseudo", uint64(result)) } return result, err } var result float64 var err error if idx.CMSketch != nil && idx.StatsVer == Version1 { - result, err = coll.getIndexRowCount(sc, idxID, indexRanges) + result, err = coll.getIndexRowCount(sctx, idxID, indexRanges) } else { - result, err = idx.GetRowCount(sc, coll, indexRanges, coll.Count) + result, err = idx.GetRowCount(sctx, coll, indexRanges, coll.Count) } if sc.EnableOptimizerCETrace { - CETraceRange(sc, coll.PhysicalID, colNames, indexRanges, "Index Stats", uint64(result)) + CETraceRange(sctx, coll.PhysicalID, colNames, indexRanges, "Index Stats", uint64(result)) } return result, errors.Trace(err) } // CETraceRange appends a list of ranges and related information into CE trace -func CETraceRange(sc *stmtctx.StatementContext, tableID int64, colNames []string, ranges []*ranger.Range, tp string, rowCount uint64) { +func CETraceRange(sctx sessionctx.Context, tableID int64, colNames []string, ranges []*ranger.Range, tp string, rowCount uint64) { + sc := sctx.GetSessionVars().StmtCtx allPoint := true for _, ran := range ranges { - if !ran.IsPointNullable(sc) { + if !ran.IsPointNullable(sctx) { allPoint = false break } @@ -572,7 +577,7 @@ func outOfRangeEQSelectivity(ndv, realtimeRowCount, columnRowCount int64) float6 } // crossValidationSelectivity gets the selectivity of multi-column equal conditions by cross validation. 
-func (coll *HistColl) crossValidationSelectivity(sc *stmtctx.StatementContext, idx *Index, usedColsLen int, idxPointRange *ranger.Range) (float64, float64, error) { +func (coll *HistColl) crossValidationSelectivity(sctx sessionctx.Context, idx *Index, usedColsLen int, idxPointRange *ranger.Range) (float64, float64, error) { minRowCount := math.MaxFloat64 cols := coll.Idx2ColumnIDs[idx.ID] crossValidationSelectivity := 1.0 @@ -582,7 +587,7 @@ func (coll *HistColl) crossValidationSelectivity(sc *stmtctx.StatementContext, i break } if col, ok := coll.Columns[colID]; ok { - if col.IsInvalid(sc, coll.Pseudo) { + if col.IsInvalid(sctx, coll.Pseudo) { continue } lowExclude := idxPointRange.LowExclude @@ -604,7 +609,7 @@ func (coll *HistColl) crossValidationSelectivity(sc *stmtctx.StatementContext, i HighExclude: highExclude, } - rowCount, err := col.GetColumnRowCount(sc, []*ranger.Range{&rang}, coll.Count, col.IsHandle) + rowCount, err := col.GetColumnRowCount(sctx, []*ranger.Range{&rang}, coll.Count, col.IsHandle) if err != nil { return 0, 0, err } @@ -619,7 +624,7 @@ func (coll *HistColl) crossValidationSelectivity(sc *stmtctx.StatementContext, i } // getEqualCondSelectivity gets the selectivity of the equal conditions. -func (coll *HistColl) getEqualCondSelectivity(sc *stmtctx.StatementContext, idx *Index, bytes []byte, usedColsLen int, idxPointRange *ranger.Range) (float64, error) { +func (coll *HistColl) getEqualCondSelectivity(sctx sessionctx.Context, idx *Index, bytes []byte, usedColsLen int, idxPointRange *ranger.Range) (float64, error) { coverAll := len(idx.Info.Columns) == usedColsLen // In this case, the row count is at most 1. if idx.Info.Unique && coverAll { @@ -646,7 +651,7 @@ func (coll *HistColl) getEqualCondSelectivity(sc *stmtctx.StatementContext, idx return outOfRangeEQSelectivity(ndv, coll.Count, int64(idx.TotalRowCount())), nil } - minRowCount, crossValidationSelectivity, err := coll.crossValidationSelectivity(sc, idx, usedColsLen, idxPointRange) + minRowCount, crossValidationSelectivity, err := coll.crossValidationSelectivity(sctx, idx, usedColsLen, idxPointRange) if err != nil { return 0, nil } @@ -658,7 +663,8 @@ func (coll *HistColl) getEqualCondSelectivity(sc *stmtctx.StatementContext, idx return idxCount / idx.TotalRowCount(), nil } -func (coll *HistColl) getIndexRowCount(sc *stmtctx.StatementContext, idxID int64, indexRanges []*ranger.Range) (float64, error) { +func (coll *HistColl) getIndexRowCount(sctx sessionctx.Context, idxID int64, indexRanges []*ranger.Range) (float64, error) { + sc := sctx.GetSessionVars().StmtCtx idx := coll.Indices[idxID] totalCount := float64(0) for _, ran := range indexRanges { @@ -675,7 +681,7 @@ func (coll *HistColl) getIndexRowCount(sc *stmtctx.StatementContext, idxID int64 // on single-column index, use previous way as well, because CMSketch does not contain null // values in this case. 
if rangePosition == 0 || isSingleColIdxNullRange(idx, ran) { - count, err := idx.GetRowCount(sc, nil, []*ranger.Range{ran}, coll.Count) + count, err := idx.GetRowCount(sctx, nil, []*ranger.Range{ran}, coll.Count) if err != nil { return 0, errors.Trace(err) } @@ -689,7 +695,7 @@ func (coll *HistColl) getIndexRowCount(sc *stmtctx.StatementContext, idxID int64 if err != nil { return 0, errors.Trace(err) } - selectivity, err = coll.getEqualCondSelectivity(sc, idx, bytes, rangePosition, ran) + selectivity, err = coll.getEqualCondSelectivity(sctx, idx, bytes, rangePosition, ran) if err != nil { return 0, errors.Trace(err) } @@ -705,7 +711,7 @@ func (coll *HistColl) getIndexRowCount(sc *stmtctx.StatementContext, idxID int64 if err != nil { return 0, err } - res, err := coll.getEqualCondSelectivity(sc, idx, bytes, rangePosition, ran) + res, err := coll.getEqualCondSelectivity(sctx, idx, bytes, rangePosition, ran) if err != nil { return 0, errors.Trace(err) } @@ -731,9 +737,9 @@ func (coll *HistColl) getIndexRowCount(sc *stmtctx.StatementContext, idxID int64 } // prefer index stats over column stats if idx, ok := coll.ColID2IdxID[colID]; ok { - count, err = coll.GetRowCountByIndexRanges(sc, idx, []*ranger.Range{&rang}) + count, err = coll.GetRowCountByIndexRanges(sctx, idx, []*ranger.Range{&rang}) } else { - count, err = coll.GetRowCountByColumnRanges(sc, colID, []*ranger.Range{&rang}) + count, err = coll.GetRowCountByColumnRanges(sctx, colID, []*ranger.Range{&rang}) } if err != nil { return 0, errors.Trace(err) diff --git a/util/ranger/types.go b/util/ranger/types.go index 2e8cc1dc6120d..f2bf561f6a3cf 100644 --- a/util/ranger/types.go +++ b/util/ranger/types.go @@ -119,8 +119,8 @@ func (ran *Range) IsPointNonNullable(sctx sessionctx.Context) bool { // IsPointNullable returns if the range is a point. 
// TODO: unify the parameter type with IsPointNullable and IsPoint -func (ran *Range) IsPointNullable(stmtCtx *stmtctx.StatementContext) bool { - return ran.isPoint(stmtCtx, true) +func (ran *Range) IsPointNullable(sctx sessionctx.Context) bool { + return ran.isPoint(sctx.GetSessionVars().StmtCtx, true) } // IsFullRange check if the range is full scan range From 5d463f3be7d2aaa30d17ab89c8f0dcece8bfafd4 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Tue, 14 Dec 2021 19:48:34 +0800 Subject: [PATCH 4/8] metrics: fix the Max SafeTS Gap metrics (#30689) --- metrics/grafana/tidb.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metrics/grafana/tidb.json b/metrics/grafana/tidb.json index 2776e7625570f..518112d04abcc 100644 --- a/metrics/grafana/tidb.json +++ b/metrics/grafana/tidb.json @@ -5426,7 +5426,7 @@ "steppedLine": false, "targets": [ { - "expr": "sum(tidb_tikvclient_safets_gap_seconds{tidb_cluster=\"$tidb_cluster\", instance=~\"$instance\"}) by (instance, store)", + "expr": "tidb_tikvclient_min_safets_gap_seconds{tidb_cluster=\"$tidb_cluster\"}", "format": "time_series", "intervalFactor": 2, "legendFormat": "{{instance}}-store-{{store}}", @@ -5438,7 +5438,7 @@ "timeFrom": null, "timeRegions": [], "timeShift": null, - "title": "Max SafeTS gap", + "title": "Max SafeTS Gap", "tooltip": { "msResolution": false, "shared": true, From 4b48e55ae9c7d2580233f846858a16338cf89b77 Mon Sep 17 00:00:00 2001 From: glorv Date: Tue, 14 Dec 2021 20:04:35 +0800 Subject: [PATCH 5/8] lightning: Add source dir existence check for s3 (#30674) --- br/pkg/backup/client_test.go | 2 -- br/pkg/lightning/lightning.go | 13 +++++++++++++ br/pkg/storage/gcs.go | 8 -------- br/pkg/storage/s3.go | 8 -------- br/pkg/storage/s3_test.go | 3 +-- br/pkg/storage/storage.go | 13 ------------- br/pkg/task/backup.go | 1 - br/pkg/task/backup_raw.go | 1 - br/pkg/task/common.go | 4 +++- br/pkg/task/restore.go | 1 - br/tests/lightning_s3/run.sh | 14 ++++++++++++++ 11 files changed, 31 insertions(+), 37 deletions(-) diff --git a/br/pkg/backup/client_test.go b/br/pkg/backup/client_test.go index e46d832bae3ee..e341f15417f55 100644 --- a/br/pkg/backup/client_test.go +++ b/br/pkg/backup/client_test.go @@ -265,7 +265,6 @@ func (r *testBackup) TestSendCreds(c *C) { c.Assert(err, IsNil) opts := &storage.ExternalStorageOptions{ SendCredentials: true, - SkipCheckPath: true, } _, err = storage.New(r.ctx, backend, opts) c.Assert(err, IsNil) @@ -284,7 +283,6 @@ func (r *testBackup) TestSendCreds(c *C) { c.Assert(err, IsNil) opts = &storage.ExternalStorageOptions{ SendCredentials: false, - SkipCheckPath: true, } _, err = storage.New(r.ctx, backend, opts) c.Assert(err, IsNil) diff --git a/br/pkg/lightning/lightning.go b/br/pkg/lightning/lightning.go index 575b661be6cac..9fc40cdf77144 100644 --- a/br/pkg/lightning/lightning.go +++ b/br/pkg/lightning/lightning.go @@ -296,6 +296,19 @@ func (l *Lightning) run(taskCtx context.Context, taskCfg *config.Config, g glue. 
return errors.Annotate(err, "create storage failed") } + // return expectedErr means at least meet one file + expectedErr := errors.New("Stop Iter") + walkErr := s.WalkDir(ctx, &storage.WalkOption{ListCount: 1}, func(string, int64) error { + // return an error when meet the first regular file to break the walk loop + return expectedErr + }) + if !errors.ErrorEqual(walkErr, expectedErr) { + if walkErr == nil { + return errors.Errorf("data-source-dir '%s' doesn't exist or contains no files", taskCfg.Mydumper.SourceDir) + } + return errors.Annotatef(walkErr, "visit data-source-dir '%s' failed", taskCfg.Mydumper.SourceDir) + } + loadTask := log.L().Begin(zap.InfoLevel, "load data source") var mdl *mydump.MDLoader mdl, err = mydump.NewMyDumpLoaderWithStore(ctx, taskCfg, s) diff --git a/br/pkg/storage/gcs.go b/br/pkg/storage/gcs.go index e4835e0eb6111..c54141b8ee560 100644 --- a/br/pkg/storage/gcs.go +++ b/br/pkg/storage/gcs.go @@ -276,14 +276,6 @@ func newGCSStorage(ctx context.Context, gcs *backuppb.GCS, opts *ExternalStorage // so we need find sst in slash directory gcs.Prefix += "//" } - // TODO remove it after BR remove cfg skip-check-path - if !opts.SkipCheckPath { - // check bucket exists - _, err = bucket.Attrs(ctx) - if err != nil { - return nil, errors.Annotatef(err, "gcs://%s/%s", gcs.Bucket, gcs.Prefix) - } - } return &gcsStorage{gcs: gcs, bucket: bucket}, nil } diff --git a/br/pkg/storage/s3.go b/br/pkg/storage/s3.go index 2c07b5af2cad0..6accafee7363d 100644 --- a/br/pkg/storage/s3.go +++ b/br/pkg/storage/s3.go @@ -283,14 +283,6 @@ func newS3Storage(backend *backuppb.S3, opts *ExternalStorageOptions) (*S3Storag } c := s3.New(ses) - // TODO remove it after BR remove cfg skip-check-path - if !opts.SkipCheckPath { - err = checkS3Bucket(c, &qs) - if err != nil { - return nil, errors.Annotatef(berrors.ErrStorageInvalidConfig, "Bucket %s is not accessible: %v", qs.Bucket, err) - } - } - if len(qs.Prefix) > 0 && !strings.HasSuffix(qs.Prefix, "/") { qs.Prefix += "/" } diff --git a/br/pkg/storage/s3_test.go b/br/pkg/storage/s3_test.go index 413f5e8881da1..cf30828b07c65 100644 --- a/br/pkg/storage/s3_test.go +++ b/br/pkg/storage/s3_test.go @@ -288,7 +288,6 @@ func (s *s3Suite) TestS3Storage(c *C) { _, err := New(ctx, s3, &ExternalStorageOptions{ SendCredentials: test.sendCredential, CheckPermissions: test.hackPermission, - SkipCheckPath: true, }) if test.errReturn { c.Assert(err, NotNil) @@ -414,7 +413,7 @@ func (s *s3Suite) TestS3Storage(c *C) { func (s *s3Suite) TestS3URI(c *C) { backend, err := ParseBackend("s3://bucket/prefix/", nil) c.Assert(err, IsNil) - storage, err := New(context.Background(), backend, &ExternalStorageOptions{SkipCheckPath: true}) + storage, err := New(context.Background(), backend, &ExternalStorageOptions{}) c.Assert(err, IsNil) c.Assert(storage.URI(), Equals, "s3://bucket/prefix/") } diff --git a/br/pkg/storage/storage.go b/br/pkg/storage/storage.go index af05abac398fa..177656fc378a0 100644 --- a/br/pkg/storage/storage.go +++ b/br/pkg/storage/storage.go @@ -121,18 +121,6 @@ type ExternalStorageOptions struct { // NoCredentials means that no cloud credentials are supplied to BR NoCredentials bool - // SkipCheckPath marks whether to skip checking path's existence. - // - // This should only be set to true in testing, to avoid interacting with the - // real world. - // When this field is false (i.e. path checking is enabled), the New() - // function will ensure the path referred by the backend exists by - // recursively creating the folders. 
This will also throw an error if such - // operation is impossible (e.g. when the bucket storing the path is missing). - - // deprecated: use checkPermissions and specify the checkPermission instead. - SkipCheckPath bool - // HTTPClient to use. The created storage may ignore this field if it is not // directly using HTTP (e.g. the local storage). HTTPClient *http.Client @@ -148,7 +136,6 @@ type ExternalStorageOptions struct { func Create(ctx context.Context, backend *backuppb.StorageBackend, sendCreds bool) (ExternalStorage, error) { return New(ctx, backend, &ExternalStorageOptions{ SendCredentials: sendCreds, - SkipCheckPath: false, HTTPClient: nil, }) } diff --git a/br/pkg/task/backup.go b/br/pkg/task/backup.go index 7a9037c20f80c..87461f53bab74 100644 --- a/br/pkg/task/backup.go +++ b/br/pkg/task/backup.go @@ -257,7 +257,6 @@ func RunBackup(c context.Context, g glue.Glue, cmdName string, cfg *BackupConfig opts := storage.ExternalStorageOptions{ NoCredentials: cfg.NoCreds, SendCredentials: cfg.SendCreds, - SkipCheckPath: cfg.SkipCheckPath, } if err = client.SetStorage(ctx, u, &opts); err != nil { return errors.Trace(err) diff --git a/br/pkg/task/backup_raw.go b/br/pkg/task/backup_raw.go index d8d11ea95c3a1..febe151218706 100644 --- a/br/pkg/task/backup_raw.go +++ b/br/pkg/task/backup_raw.go @@ -150,7 +150,6 @@ func RunBackupRaw(c context.Context, g glue.Glue, cmdName string, cfg *RawKvConf opts := storage.ExternalStorageOptions{ NoCredentials: cfg.NoCreds, SendCredentials: cfg.SendCreds, - SkipCheckPath: cfg.SkipCheckPath, } if err = client.SetStorage(ctx, u, &opts); err != nil { return errors.Trace(err) diff --git a/br/pkg/task/common.go b/br/pkg/task/common.go index 4ae54f03cde5a..357c7d267e449 100644 --- a/br/pkg/task/common.go +++ b/br/pkg/task/common.go @@ -485,6 +485,9 @@ func (cfg *Config) ParseFromFlags(flags *pflag.FlagSet) error { if cfg.SkipCheckPath, err = flags.GetBool(flagSkipCheckPath); err != nil { return errors.Trace(err) } + if cfg.SkipCheckPath { + log.L().Info("--skip-check-path is deprecated, need explicitly set it anymore") + } if err = cfg.parseCipherInfo(flags); err != nil { return errors.Trace(err) @@ -548,7 +551,6 @@ func storageOpts(cfg *Config) *storage.ExternalStorageOptions { return &storage.ExternalStorageOptions{ NoCredentials: cfg.NoCreds, SendCredentials: cfg.SendCreds, - SkipCheckPath: cfg.SkipCheckPath, } } diff --git a/br/pkg/task/restore.go b/br/pkg/task/restore.go index a80549d005905..ae46f15b1f6ce 100644 --- a/br/pkg/task/restore.go +++ b/br/pkg/task/restore.go @@ -249,7 +249,6 @@ func RunRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf opts := storage.ExternalStorageOptions{ NoCredentials: cfg.NoCreds, SendCredentials: cfg.SendCreds, - SkipCheckPath: cfg.SkipCheckPath, } if err = client.SetStorage(ctx, u, &opts); err != nil { return errors.Trace(err) diff --git a/br/tests/lightning_s3/run.sh b/br/tests/lightning_s3/run.sh index 6fed0af2b81da..5b2973784fd7e 100755 --- a/br/tests/lightning_s3/run.sh +++ b/br/tests/lightning_s3/run.sh @@ -62,6 +62,20 @@ _EOF_ run_sql "DROP DATABASE IF EXISTS $DB;" run_sql "DROP TABLE IF EXISTS $DB.$TABLE;" +# test not exist path +rm -f $TEST_DIR/lightning.log +SOURCE_DIR="s3://$BUCKET/not-exist-path?endpoint=http%3A//127.0.0.1%3A9900&access_key=$MINIO_ACCESS_KEY&secret_access_key=$MINIO_SECRET_KEY&force_path_style=true" +! 
run_lightning -d $SOURCE_DIR --backend local 2> /dev/null +grep -Eq "data-source-dir .* doesn't exist or contains no files" $TEST_DIR/lightning.log + +# test empty dir +rm -f $TEST_DIR/lightning.log +emptyPath=empty-bucket/empty-path +mkdir -p $DBPATH/$emptyPath +SOURCE_DIR="s3://$emptyPath/not-exist-path?endpoint=http%3A//127.0.0.1%3A9900&access_key=$MINIO_ACCESS_KEY&secret_access_key=$MINIO_SECRET_KEY&force_path_style=true" +! run_lightning -d $SOURCE_DIR --backend local 2> /dev/null +grep -Eq "data-source-dir .* doesn't exist or contains no files" $TEST_DIR/lightning.log + SOURCE_DIR="s3://$BUCKET/?endpoint=http%3A//127.0.0.1%3A9900&access_key=$MINIO_ACCESS_KEY&secret_access_key=$MINIO_SECRET_KEY&force_path_style=true" run_lightning -d $SOURCE_DIR --backend local 2> /dev/null run_sql "SELECT count(*), sum(i) FROM \`$DB\`.$TABLE" From 950a274afc8b54fc081b77e4d8336760a974bcf4 Mon Sep 17 00:00:00 2001 From: wangggong <793160615@qq.com> Date: Tue, 14 Dec 2021 21:10:35 +0800 Subject: [PATCH 6/8] golangci-lint: support durationcheck (#30027) --- .golangci.yml | 1 + server/util.go | 8 ++++---- session/session.go | 3 ++- types/time.go | 22 +++++++++++----------- util/execdetails/execdetails.go | 2 +- 5 files changed, 19 insertions(+), 17 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 816e2404a9e36..d262ed0e0457b 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -22,6 +22,7 @@ linters: - rowserrcheck - unconvert - makezero + - durationcheck linters-settings: staticcheck: diff --git a/server/util.go b/server/util.go index 2a45dce017dc2..f5518f0396a9a 100644 --- a/server/util.go +++ b/server/util.go @@ -179,16 +179,16 @@ func dumpBinaryTime(dur time.Duration) (data []byte) { dur = -dur } days := dur / (24 * time.Hour) - dur -= days * 24 * time.Hour + dur -= days * 24 * time.Hour //nolint:durationcheck data[2] = byte(days) hours := dur / time.Hour - dur -= hours * time.Hour + dur -= hours * time.Hour //nolint:durationcheck data[6] = byte(hours) minutes := dur / time.Minute - dur -= minutes * time.Minute + dur -= minutes * time.Minute //nolint:durationcheck data[7] = byte(minutes) seconds := dur / time.Second - dur -= seconds * time.Second + dur -= seconds * time.Second //nolint:durationcheck data[8] = byte(seconds) if dur == 0 { data[0] = 8 diff --git a/session/session.go b/session/session.go index 8a2b61e50e6db..bfc5288a7ff4f 100644 --- a/session/session.go +++ b/session/session.go @@ -1549,7 +1549,8 @@ func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlex failpoint.Inject("mockStmtSlow", func(val failpoint.Value) { if strings.Contains(stmtNode.Text(), "/* sleep */") { - time.Sleep(time.Duration(val.(int)) * time.Millisecond) + v, _ := val.(int) + time.Sleep(time.Duration(v) * time.Millisecond) } }) diff --git a/types/time.go b/types/time.go index 8d2a11fc58d90..0ea01c8a85f96 100644 --- a/types/time.go +++ b/types/time.go @@ -49,9 +49,9 @@ const ( // MaxDuration is the maximum for duration. MaxDuration int64 = 838*10000 + 59*100 + 59 // MinTime is the minimum for mysql time type. - MinTime = -gotime.Duration(838*3600+59*60+59) * gotime.Second + MinTime = -(838*gotime.Hour + 59*gotime.Minute + 59*gotime.Second) // MaxTime is the maximum for mysql time type. - MaxTime = gotime.Duration(838*3600+59*60+59) * gotime.Second + MaxTime = 838*gotime.Hour + 59*gotime.Minute + 59*gotime.Second // ZeroDatetimeStr is the string representation of a zero datetime. ZeroDatetimeStr = "0000-00-00 00:00:00" // ZeroDateStr is the string representation of a zero date. 
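The durationcheck linter enabled in this patch flags expressions that multiply two time.Duration values, which are usually unit bugs; where the arithmetic is intentional, the hunks above and below silence it with //nolint:durationcheck. A small stand-alone sketch of the same pattern (not lifted verbatim from the patch):

```go
package main

import (
	"fmt"
	"time"
)

// splitDays mirrors the day/remainder arithmetic used in server/util.go:
// the quotient d is a time.Duration that really holds a day count, so
// multiplying it back by 24*time.Hour is deliberate, and the durationcheck
// warning is suppressed on that line.
func splitDays(dur time.Duration) (days int64, rest time.Duration) {
	d := dur / (24 * time.Hour)
	dur -= d * 24 * time.Hour //nolint:durationcheck
	return int64(d), dur
}

func main() {
	days, rest := splitDays(50*time.Hour + 30*time.Minute)
	fmt.Println(days, rest) // 2 2h30m0s
}
```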
@@ -466,7 +466,7 @@ func (t Time) ConvertToDuration() (Duration, error) { hour, minute, second := t.Clock() frac := t.Microsecond() * 1000 - d := gotime.Duration(hour*3600+minute*60+second)*gotime.Second + gotime.Duration(frac) + d := gotime.Duration(hour*3600+minute*60+second)*gotime.Second + gotime.Duration(frac) //nolint:durationcheck // TODO: check convert validation return Duration{Duration: d, Fsp: t.Fsp()}, nil } @@ -579,7 +579,7 @@ func RoundFrac(t gotime.Time, fsp int8) (gotime.Time, error) { if err != nil { return t, errors.Trace(err) } - return t.Round(gotime.Duration(math.Pow10(9-int(fsp))) * gotime.Nanosecond), nil + return t.Round(gotime.Duration(math.Pow10(9-int(fsp))) * gotime.Nanosecond), nil //nolint:durationcheck } // TruncateFrac truncates fractional seconds precision with new fsp and returns a new one. @@ -589,7 +589,7 @@ func TruncateFrac(t gotime.Time, fsp int8) (gotime.Time, error) { if _, err := CheckFsp(int(fsp)); err != nil { return t, err } - return t.Truncate(gotime.Duration(math.Pow10(9-int(fsp))) * gotime.Nanosecond), nil + return t.Truncate(gotime.Duration(math.Pow10(9-int(fsp))) * gotime.Nanosecond), nil //nolint:durationcheck } // ToPackedUint encodes Time to a packed uint64 value. @@ -1270,7 +1270,7 @@ func AdjustYear(y int64, adjustZero bool) (int64, error) { // NewDuration construct duration with time. func NewDuration(hour, minute, second, microsecond int, fsp int8) Duration { return Duration{ - Duration: gotime.Duration(hour)*gotime.Hour + gotime.Duration(minute)*gotime.Minute + gotime.Duration(second)*gotime.Second + gotime.Duration(microsecond)*gotime.Microsecond, + Duration: gotime.Duration(hour)*gotime.Hour + gotime.Duration(minute)*gotime.Minute + gotime.Duration(second)*gotime.Second + gotime.Duration(microsecond)*gotime.Microsecond, //nolint:durationcheck Fsp: fsp, } } @@ -1490,7 +1490,7 @@ func (d Duration) RoundFrac(fsp int8, loc *gotime.Location) (Duration, error) { } n := gotime.Date(0, 0, 0, 0, 0, 0, 0, tz) - nd := n.Add(d.Duration).Round(gotime.Duration(math.Pow10(9-int(fsp))) * gotime.Nanosecond).Sub(n) + nd := n.Add(d.Duration).Round(gotime.Duration(math.Pow10(9-int(fsp))) * gotime.Nanosecond).Sub(n) //nolint:durationcheck return Duration{Duration: nd, Fsp: fsp}, nil } @@ -1711,7 +1711,7 @@ func matchDuration(str string, fsp int8) (Duration, error) { return Duration{t, fsp}, ErrTruncatedWrongVal.GenWithStackByArgs("time", str) } - d := gotime.Duration(hhmmss[0]*3600+hhmmss[1]*60+hhmmss[2])*gotime.Second + gotime.Duration(frac)*gotime.Microsecond + d := gotime.Duration(hhmmss[0]*3600+hhmmss[1]*60+hhmmss[2])*gotime.Second + gotime.Duration(frac)*gotime.Microsecond //nolint:durationcheck if negative { d = -d } @@ -1800,11 +1800,11 @@ func splitDuration(t gotime.Duration) (int, int, int, int, int) { } hours := t / gotime.Hour - t -= hours * gotime.Hour + t -= hours * gotime.Hour //nolint:durationcheck minutes := t / gotime.Minute - t -= minutes * gotime.Minute + t -= minutes * gotime.Minute //nolint:durationcheck seconds := t / gotime.Second - t -= seconds * gotime.Second + t -= seconds * gotime.Second //nolint:durationcheck fraction := t / gotime.Microsecond return sign, int(hours), int(minutes), int(seconds), int(fraction) diff --git a/util/execdetails/execdetails.go b/util/execdetails/execdetails.go index 4265145c2d66a..ea1e4cda1e746 100644 --- a/util/execdetails/execdetails.go +++ b/util/execdetails/execdetails.go @@ -968,7 +968,7 @@ func FormatDuration(d time.Duration) string { if unit == time.Nanosecond { return d.String() } - integer 
:= (d / unit) * unit + integer := (d / unit) * unit //nolint:durationcheck decimal := float64(d%unit) / float64(unit) if d < 10*unit { decimal = math.Round(decimal*100) / 100 From 6c0fcea070af6338dc988fad725911637cea02c2 Mon Sep 17 00:00:00 2001 From: HuaiyuXu <391585975@qq.com> Date: Tue, 14 Dec 2021 21:46:35 +0800 Subject: [PATCH 7/8] executor: fix data race on IndexHashJoin.cancelFunc (#30701) --- executor/index_lookup_hash_join.go | 2 +- executor/index_lookup_join.go | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/executor/index_lookup_hash_join.go b/executor/index_lookup_hash_join.go index 75d84c2162480..0beb3e59e66b1 100644 --- a/executor/index_lookup_hash_join.go +++ b/executor/index_lookup_hash_join.go @@ -148,6 +148,7 @@ func (e *IndexNestedLoopHashJoin) Open(ctx context.Context) error { } e.memTracker = memory.NewTracker(e.id, -1) e.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker) + e.cancelFunc = nil e.innerPtrBytes = make([][]byte, 0, 8) if e.runtimeStats != nil { e.stats = &indexLookUpJoinRuntimeStats{} @@ -311,7 +312,6 @@ func (e *IndexNestedLoopHashJoin) isDryUpTasks(ctx context.Context) bool { func (e *IndexNestedLoopHashJoin) Close() error { if e.cancelFunc != nil { e.cancelFunc() - e.cancelFunc = nil } if e.resultCh != nil { for range e.resultCh { diff --git a/executor/index_lookup_join.go b/executor/index_lookup_join.go index ff5d317e2bb5b..4be2f24272ae4 100644 --- a/executor/index_lookup_join.go +++ b/executor/index_lookup_join.go @@ -174,6 +174,7 @@ func (e *IndexLookUpJoin) Open(ctx context.Context) error { e.stats = &indexLookUpJoinRuntimeStats{} e.ctx.GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.id, e.stats) } + e.cancelFunc = nil e.startWorkers(ctx) return nil } From 813f6efd41ce3276e5737efa5734c1795048477a Mon Sep 17 00:00:00 2001 From: Morgan Tocker Date: Tue, 14 Dec 2021 18:40:35 -0700 Subject: [PATCH 8/8] sessionctx/variable: change tidb_store_limit to global only (#30522) --- executor/set_test.go | 11 +++-------- sessionctx/variable/sysvar.go | 5 +---- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/executor/set_test.go b/executor/set_test.go index 9be1f1794ce1c..6b166059e6921 100644 --- a/executor/set_test.go +++ b/executor/set_test.go @@ -366,16 +366,11 @@ func (s *testSerialSuite1) TestSetVar(c *C) { tk.MustExec("set @@tidb_expensive_query_time_threshold=70") tk.MustQuery("select @@tidb_expensive_query_time_threshold;").Check(testkit.Rows("70")) - tk.MustQuery("select @@tidb_store_limit;").Check(testkit.Rows("0")) - tk.MustExec("set @@tidb_store_limit = 100") - tk.MustQuery("select @@tidb_store_limit;").Check(testkit.Rows("100")) - tk.MustQuery("select @@session.tidb_store_limit;").Check(testkit.Rows("100")) tk.MustQuery("select @@global.tidb_store_limit;").Check(testkit.Rows("0")) - tk.MustExec("set @@tidb_store_limit = 0") - + tk.MustExec("set @@global.tidb_store_limit = 100") + tk.MustQuery("select @@global.tidb_store_limit;").Check(testkit.Rows("100")) + tk.MustExec("set @@global.tidb_store_limit = 0") tk.MustExec("set global tidb_store_limit = 100") - tk.MustQuery("select @@tidb_store_limit;").Check(testkit.Rows("100")) - tk.MustQuery("select @@session.tidb_store_limit;").Check(testkit.Rows("100")) tk.MustQuery("select @@global.tidb_store_limit;").Check(testkit.Rows("100")) tk.MustQuery("select @@session.tidb_metric_query_step;").Check(testkit.Rows("60")) diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 60543c00d334e..fc7ce09cae6a7 100644 --- 
a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -1036,10 +1036,7 @@ var defaultSysVars = []*SysVar{ } return nil }}, - {Scope: ScopeGlobal | ScopeSession, Name: TiDBStoreLimit, Value: strconv.FormatInt(atomic.LoadInt64(&config.GetGlobalConfig().TiKVClient.StoreLimit), 10), Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt64, SetSession: func(s *SessionVars, val string) error { - tikvstore.StoreLimit.Store(tidbOptInt64(val, DefTiDBStoreLimit)) - return nil - }, GetSession: func(s *SessionVars) (string, error) { + {Scope: ScopeGlobal, Name: TiDBStoreLimit, Value: strconv.FormatInt(atomic.LoadInt64(&config.GetGlobalConfig().TiKVClient.StoreLimit), 10), Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt64, GetGlobal: func(s *SessionVars) (string, error) { return strconv.FormatInt(tikvstore.StoreLimit.Load(), 10), nil }, SetGlobal: func(s *SessionVars, val string) error { tikvstore.StoreLimit.Store(tidbOptInt64(val, DefTiDBStoreLimit))
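The final hunk makes tidb_store_limit global-only: the session getter/setter is removed and the variable is served purely through GetGlobal/SetGlobal hooks backed by the shared tikvstore.StoreLimit atomic. A simplified stand-alone sketch of that shape follows; the SysVar struct and the atomic below are stand-ins, not the real TiDB definitions.

```go
package main

import (
	"fmt"
	"strconv"
	"sync/atomic"
)

// storeLimit stands in for tikvstore.StoreLimit: one process-wide value that
// every session observes, which is why a session-scoped setter is unnecessary.
var storeLimit atomic.Int64

// SysVar is a stripped-down stand-in for sessionctx/variable.SysVar with only
// the global hooks kept.
type SysVar struct {
	Name      string
	GetGlobal func() (string, error)
	SetGlobal func(val string) error
}

var tidbStoreLimit = SysVar{
	Name: "tidb_store_limit",
	GetGlobal: func() (string, error) {
		return strconv.FormatInt(storeLimit.Load(), 10), nil
	},
	SetGlobal: func(val string) error {
		n, err := strconv.ParseInt(val, 10, 64)
		if err != nil {
			return err
		}
		storeLimit.Store(n)
		return nil
	},
}

func main() {
	_ = tidbStoreLimit.SetGlobal("100")
	v, _ := tidbStoreLimit.GetGlobal()
	fmt.Println(v) // 100
}
```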