diff --git a/dumpling/export/BUILD.bazel b/dumpling/export/BUILD.bazel
index 08cc7fe2e664e..fc3006cb855c0 100644
--- a/dumpling/export/BUILD.bazel
+++ b/dumpling/export/BUILD.bazel
@@ -106,6 +106,7 @@ go_test(
         "@com_github_data_dog_go_sqlmock//:go-sqlmock",
         "@com_github_go_sql_driver_mysql//:mysql",
         "@com_github_pingcap_errors//:errors",
+        "@com_github_pingcap_failpoint//:failpoint",
         "@com_github_prometheus_client_golang//prometheus/collectors",
         "@com_github_stretchr_testify//require",
         "@org_golang_x_sync//errgroup",
diff --git a/dumpling/export/dump.go b/dumpling/export/dump.go
index 7e5a81e0f3ce1..b5f5a00af1974 100644
--- a/dumpling/export/dump.go
+++ b/dumpling/export/dump.go
@@ -1531,7 +1531,7 @@ func setSessionParam(d *Dumper) error {
 			d.L().Info("cannot check whether TiDB has TiKV, will apply tidb_snapshot by default. This won't affect dump process", log.ShortError(err))
 		}
 		if conf.ServerInfo.HasTiKV {
-			sessionParam["tidb_snapshot"] = snapshot
+			sessionParam[snapshotVar] = snapshot
 		}
 	}
 }
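The `snapshotVar` constant introduced here pairs with the error annotation added in `sql.go` further down. As a quick illustration (not part of the patch), `errors.Annotate` from `github.com/pingcap/errors` keeps the driver error as the cause while prepending the user-facing hint that the new test asserts on; the message text mirrors the patched string:

```go
package main

import (
	"fmt"

	"github.com/pingcap/errors"
)

func main() {
	// The cause mimics the TiKV error that the dumpling test mocks below.
	cause := errors.New("can not get 'tikv_gc_safe_point'")
	err := errors.Annotate(cause,
		"failed to set tidb_snapshot, please set --consistency=none/--consistency=lock or fix the snapshot problem")
	fmt.Println(err)               // annotated message; still matched by ErrorContains(..., "consistency=none")
	fmt.Println(errors.Cause(err)) // the original error remains retrievable
}
```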
diff --git a/dumpling/export/dump_test.go b/dumpling/export/dump_test.go
index c9a40bba28d6f..7cbc52e341451 100644
--- a/dumpling/export/dump_test.go
+++ b/dumpling/export/dump_test.go
@@ -9,7 +9,9 @@ import (
 	"time"
 
 	"github.com/DATA-DOG/go-sqlmock"
+	"github.com/go-sql-driver/mysql"
 	"github.com/pingcap/errors"
+	"github.com/pingcap/failpoint"
 	"github.com/pingcap/tidb/br/pkg/version"
 	tcontext "github.com/pingcap/tidb/dumpling/context"
 	"github.com/pingcap/tidb/parser"
@@ -224,3 +226,64 @@ func TestUnregisterMetrics(t *testing.T) {
 	// should not panic
 	require.Error(t, err)
 }
+
+func TestSetSessionParams(t *testing.T) {
+	// case 1: fail to set tidb_snapshot, should return an error with a hint
+	db, mock, err := sqlmock.New()
+	require.NoError(t, err)
+	defer func() {
+		require.NoError(t, db.Close())
+	}()
+
+	mock.ExpectQuery("SELECT @@tidb_config").
+		WillReturnError(errors.New("mock error"))
+	mock.ExpectQuery("SELECT COUNT\\(1\\) as c FROM MYSQL.TiDB WHERE VARIABLE_NAME='tikv_gc_safe_point'").
+		WillReturnError(errors.New("mock error"))
+	tikvErr := &mysql.MySQLError{
+		Number:  1105,
+		Message: "can not get 'tikv_gc_safe_point'",
+	}
+	mock.ExpectExec("SET SESSION tidb_snapshot").
+		WillReturnError(tikvErr)
+
+	require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/dumpling/export/SkipResetDB", "return(true)"))
+	defer failpoint.Disable("github.com/pingcap/tidb/dumpling/export/SkipResetDB") //nolint:errcheck
+
+	tctx, cancel := tcontext.Background().WithLogger(appLogger).WithCancel()
+	defer cancel()
+
+	conf := DefaultConfig()
+	conf.ServerInfo = version.ServerInfo{
+		ServerType: version.ServerTypeTiDB,
+		HasTiKV:    false,
+	}
+	conf.Snapshot = "439153276059648000"
+	conf.Consistency = ConsistencyTypeSnapshot
+	d := &Dumper{
+		tctx:      tctx,
+		conf:      conf,
+		cancelCtx: cancel,
+		dbHandle:  db,
+	}
+	err = setSessionParam(d)
+	require.ErrorContains(t, err, "consistency=none")
+
+	// case 2: failing to set an unknown session variable should not return an error
+	conf.ServerInfo = version.ServerInfo{
+		ServerType: version.ServerTypeMySQL,
+		HasTiKV:    false,
+	}
+	conf.Snapshot = ""
+	conf.Consistency = ConsistencyTypeFlush
+	conf.SessionParams = map[string]interface{}{
+		"mock": "UTC",
+	}
+	d.dbHandle = db
+	mock.ExpectExec("SET SESSION mock").
+		WillReturnError(errors.New("Unknown system variable mock"))
+	mock.ExpectClose()
+	mock.ExpectClose()
+
+	err = setSessionParam(d)
+	require.NoError(t, err)
+}
diff --git a/dumpling/export/sql.go b/dumpling/export/sql.go
index 837bec568b9a7..60d14ac49e14c 100644
--- a/dumpling/export/sql.go
+++ b/dumpling/export/sql.go
@@ -29,6 +29,7 @@ import (
 
 const (
 	orderByTiDBRowID = "ORDER BY `_tidb_rowid`"
+	snapshotVar      = "tidb_snapshot"
 )
 
 type listTableType int
@@ -851,7 +852,9 @@ func resetDBWithSessionParams(tctx *tcontext.Context, db *sql.DB, cfg *mysql.Con
 		s := fmt.Sprintf("SET SESSION %s = ?", k)
 		_, err := db.ExecContext(tctx, s, pv)
 		if err != nil {
-			if isUnknownSystemVariableErr(err) {
+			if k == snapshotVar {
+				err = errors.Annotate(err, "failed to set tidb_snapshot, please set --consistency=none/--consistency=lock or fix the snapshot problem")
+			} else if isUnknownSystemVariableErr(err) {
 				tctx.L().Info("session variable is not supported by db", zap.String("variable", k), zap.Reflect("value", v))
 				continue
 			}
@@ -876,6 +879,9 @@ func resetDBWithSessionParams(tctx *tcontext.Context, db *sql.DB, cfg *mysql.Con
 		}
 		cfg.Params[k] = s
 	}
+	failpoint.Inject("SkipResetDB", func(_ failpoint.Value) {
+		failpoint.Return(db, nil)
+	})
 	db.Close()
 
 	c, err := mysql.NewConnector(cfg)
diff --git a/executor/adapter.go b/executor/adapter.go
index 5e12cce1ccc69..61a10f9807b74 100644
--- a/executor/adapter.go
+++ b/executor/adapter.go
@@ -734,12 +734,7 @@ func (a *ExecStmt) handleNoDelay(ctx context.Context, e Executor, isPessimistic
 		// done in the `defer` function. If the rs is not nil, the detachment will be done in
 		// `rs.Close` in `handleStmt`
 		if handled && sc != nil && rs == nil {
-			if sc.MemTracker != nil {
-				sc.MemTracker.Detach()
-			}
-			if sc.DiskTracker != nil {
-				sc.DiskTracker.Detach()
-			}
+			sc.DetachMemDiskTracker()
 		}
 	}()
 
@@ -1417,15 +1412,7 @@ func (a *ExecStmt) FinishExecuteStmt(txnTS uint64, err error, hasMoreResults boo
 func (a *ExecStmt) CloseRecordSet(txnStartTS uint64, lastErr error) {
 	a.FinishExecuteStmt(txnStartTS, lastErr, false)
 	a.logAudit()
-	// Detach the Memory and disk tracker for the previous stmtCtx from GlobalMemoryUsageTracker and GlobalDiskUsageTracker
-	if stmtCtx := a.Ctx.GetSessionVars().StmtCtx; stmtCtx != nil {
-		if stmtCtx.DiskTracker != nil {
-			stmtCtx.DiskTracker.Detach()
-		}
-		if stmtCtx.MemTracker != nil {
-			stmtCtx.MemTracker.Detach()
-		}
-	}
+	a.Ctx.GetSessionVars().StmtCtx.DetachMemDiskTracker()
 }
 
 // LogSlowQuery is used to print the slow query in the log files.
diff --git a/executor/hash_table.go b/executor/hash_table.go
index 50acc4447f4df..4afdef58e1973 100644
--- a/executor/hash_table.go
+++ b/executor/hash_table.go
@@ -153,7 +153,7 @@ func (c *hashRowContainer) GetMatchedRows(probeKey uint64, probeRow chunk.Row, h
 }
 
 func (c *hashRowContainer) GetAllMatchedRows(probeHCtx *hashContext, probeSideRow chunk.Row,
-	probeKeyNullBits *bitmap.ConcurrentBitmap, matched []chunk.Row, needCheckBuildRowPos, needCheckProbeRowPos []int) ([]chunk.Row, error) {
+	probeKeyNullBits *bitmap.ConcurrentBitmap, matched []chunk.Row, needCheckBuildColPos, needCheckProbeColPos []int, needCheckBuildTypes, needCheckProbeTypes []*types.FieldType) ([]chunk.Row, error) {
 	// for NAAJ probe row with null, we should match them with all build rows.
 	var (
 		ok  bool
@@ -180,16 +180,20 @@ func (c *hashRowContainer) GetAllMatchedRows(probeHCtx *hashContext, probeSideRo
 		// else like
 		// (null, 1, 2), we should use the not-null probe bit to filter rows. Only fetch rows like
 		// (?, 1, 2), that exactly with value as 1 and 2 in the second and third join key column.
-		needCheckProbeRowPos = needCheckProbeRowPos[:0]
-		needCheckBuildRowPos = needCheckBuildRowPos[:0]
+		needCheckProbeColPos = needCheckProbeColPos[:0]
+		needCheckBuildColPos = needCheckBuildColPos[:0]
+		needCheckBuildTypes = needCheckBuildTypes[:0]
+		needCheckProbeTypes = needCheckProbeTypes[:0]
 		keyColLen := len(c.hCtx.naKeyColIdx)
 		for i := 0; i < keyColLen; i++ {
 			// since all bucket is from hash table (Not Null), so the buildSideNullBits check is eliminated.
 			if probeKeyNullBits.UnsafeIsSet(i) {
 				continue
 			}
-			needCheckBuildRowPos = append(needCheckBuildRowPos, c.hCtx.naKeyColIdx[i])
-			needCheckProbeRowPos = append(needCheckProbeRowPos, probeHCtx.naKeyColIdx[i])
+			needCheckBuildColPos = append(needCheckBuildColPos, c.hCtx.naKeyColIdx[i])
+			needCheckBuildTypes = append(needCheckBuildTypes, c.hCtx.allTypes[i])
+			needCheckProbeColPos = append(needCheckProbeColPos, probeHCtx.naKeyColIdx[i])
+			needCheckProbeTypes = append(needCheckProbeTypes, probeHCtx.allTypes[i])
 		}
 	}
 	var mayMatchedRow chunk.Row
@@ -200,7 +204,7 @@ func (c *hashRowContainer) GetAllMatchedRows(probeHCtx *hashContext, probeSideRo
 		}
 		if probeKeyNullBits != nil && len(probeHCtx.naKeyColIdx) > 1 {
 			// check the idxs-th value of the join columns.
-			ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, c.hCtx.allTypes, needCheckBuildRowPos, probeSideRow, probeHCtx.allTypes, needCheckProbeRowPos)
+			ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, needCheckBuildTypes, needCheckBuildColPos, probeSideRow, needCheckProbeTypes, needCheckProbeColPos)
 			if err != nil {
 				return nil, err
 			}
@@ -287,7 +291,7 @@ func (c *hashRowContainer) GetMatchedRowsAndPtrs(probeKey uint64, probeRow chunk
 }
 
 func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRow chunk.Row,
-	probeKeyNullBits *bitmap.ConcurrentBitmap, matched []chunk.Row, needCheckBuildRowPos, needCheckProbeRowPos []int) ([]chunk.Row, error) {
+	probeKeyNullBits *bitmap.ConcurrentBitmap, matched []chunk.Row, needCheckBuildColPos, needCheckProbeColPos []int, needCheckBuildTypes, needCheckProbeTypes []*types.FieldType) ([]chunk.Row, error) {
 	var (
 		ok  bool
 		err error
@@ -306,8 +310,10 @@ func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRo
 		// case2: left side (probe side) don't have null
 		// left side key <1, 2>, actually we should fetch <1,null>, <null, 2> from the null bucket because
 		// case like <3,null> is obviously not matched with the probe key.
- needCheckProbeRowPos = needCheckProbeRowPos[:0] - needCheckBuildRowPos = needCheckBuildRowPos[:0] + needCheckProbeColPos = needCheckProbeColPos[:0] + needCheckBuildColPos = needCheckBuildColPos[:0] + needCheckBuildTypes = needCheckBuildTypes[:0] + needCheckProbeTypes = needCheckProbeTypes[:0] keyColLen := len(c.hCtx.naKeyColIdx) if probeKeyNullBits != nil { // when the probeKeyNullBits is not nil, it means the probe key has null values, where we should distinguish @@ -325,11 +331,13 @@ func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRo if probeKeyNullBits.UnsafeIsSet(i) || nullEntry.nullBitMap.UnsafeIsSet(i) { continue } - needCheckBuildRowPos = append(needCheckBuildRowPos, c.hCtx.naKeyColIdx[i]) - needCheckProbeRowPos = append(needCheckProbeRowPos, probeHCtx.naKeyColIdx[i]) + needCheckBuildColPos = append(needCheckBuildColPos, c.hCtx.naKeyColIdx[i]) + needCheckBuildTypes = append(needCheckBuildTypes, c.hCtx.allTypes[i]) + needCheckProbeColPos = append(needCheckProbeColPos, probeHCtx.naKeyColIdx[i]) + needCheckProbeTypes = append(needCheckProbeTypes, probeHCtx.allTypes[i]) } // check the idxs-th value of the join columns. - ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, c.hCtx.allTypes, needCheckBuildRowPos, probeSideRow, probeHCtx.allTypes, needCheckProbeRowPos) + ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, needCheckBuildTypes, needCheckBuildColPos, probeSideRow, needCheckProbeTypes, needCheckProbeColPos) if err != nil { return nil, err } @@ -346,11 +354,13 @@ func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRo if nullEntry.nullBitMap.UnsafeIsSet(i) { continue } - needCheckBuildRowPos = append(needCheckBuildRowPos, c.hCtx.naKeyColIdx[i]) - needCheckProbeRowPos = append(needCheckProbeRowPos, probeHCtx.naKeyColIdx[i]) + needCheckBuildColPos = append(needCheckBuildColPos, c.hCtx.naKeyColIdx[i]) + needCheckBuildTypes = append(needCheckBuildTypes, c.hCtx.allTypes[i]) + needCheckProbeColPos = append(needCheckProbeColPos, probeHCtx.naKeyColIdx[i]) + needCheckProbeTypes = append(needCheckProbeTypes, probeHCtx.allTypes[i]) } // check the idxs-th value of the join columns. - ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, c.hCtx.allTypes, needCheckBuildRowPos, probeSideRow, probeHCtx.allTypes, needCheckProbeRowPos) + ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, needCheckBuildTypes, needCheckBuildColPos, probeSideRow, needCheckProbeTypes, needCheckProbeColPos) if err != nil { return nil, err } @@ -366,6 +376,11 @@ func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRo // matchJoinKey checks if join keys of buildRow and probeRow are logically equal. 
func (c *hashRowContainer) matchJoinKey(buildRow, probeRow chunk.Row, probeHCtx *hashContext) (ok bool, err error) { + if len(c.hCtx.naKeyColIdx) > 0 { + return codec.EqualChunkRow(c.sc, + buildRow, c.hCtx.allTypes, c.hCtx.naKeyColIdx, + probeRow, probeHCtx.allTypes, probeHCtx.naKeyColIdx) + } return codec.EqualChunkRow(c.sc, buildRow, c.hCtx.allTypes, c.hCtx.keyColIdx, probeRow, probeHCtx.allTypes, probeHCtx.keyColIdx) diff --git a/executor/issuetest/BUILD.bazel b/executor/issuetest/BUILD.bazel index 77bfaf7f11290..1c2955d69327b 100644 --- a/executor/issuetest/BUILD.bazel +++ b/executor/issuetest/BUILD.bazel @@ -15,10 +15,12 @@ go_test( "//parser/auth", "//parser/charset", "//parser/mysql", + "//session", "//sessionctx/variable", "//statistics", "//testkit", "//util", + "//util/memory", "@com_github_pingcap_failpoint//:failpoint", "@com_github_stretchr_testify//require", "@com_github_tikv_client_go_v2//tikv", diff --git a/executor/issuetest/executor_issue_test.go b/executor/issuetest/executor_issue_test.go index f528a54adb8a0..5b1281c7eb722 100644 --- a/executor/issuetest/executor_issue_test.go +++ b/executor/issuetest/executor_issue_test.go @@ -20,6 +20,7 @@ import ( "math/rand" "strings" "testing" + "time" "github.com/pingcap/failpoint" "github.com/pingcap/tidb/config" @@ -27,10 +28,12 @@ import ( "github.com/pingcap/tidb/parser/auth" "github.com/pingcap/tidb/parser/charset" "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/util" + "github.com/pingcap/tidb/util/memory" "github.com/stretchr/testify/require" ) @@ -1348,3 +1351,46 @@ func TestIssue40158(t *testing.T) { tk.MustExec("insert into t1 values (1, null);") tk.MustQuery("select * from t1 where c1 is null and _id < 1;").Check(testkit.Rows()) } + +func TestIssue42662(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + tk.Session().GetSessionVars().ConnectionID = 12345 + tk.Session().GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSession, -1) + tk.Session().GetSessionVars().MemTracker.SessionID = 12345 + tk.Session().GetSessionVars().MemTracker.IsRootTrackerOfSess = true + + sm := &testkit.MockSessionManager{ + PS: []*util.ProcessInfo{tk.Session().ShowProcess()}, + } + sm.Conn = make(map[uint64]session.Session) + sm.Conn[tk.Session().GetSessionVars().ConnectionID] = tk.Session() + dom.ServerMemoryLimitHandle().SetSessionManager(sm) + go dom.ServerMemoryLimitHandle().Run() + + tk.MustExec("use test") + tk.MustQuery("select connection_id()").Check(testkit.Rows("12345")) + tk.MustExec("drop table if exists t1, t2") + tk.MustExec("create table t1 (a int, b int, c int)") + tk.MustExec("create table t2 (a int, b int, c int)") + tk.MustExec("insert into t1 values (1, 1, 1), (1, 2, 2), (2, 1, 3), (2, 2, 4)") + tk.MustExec("insert into t2 values (1, 1, 1), (1, 2, 2), (2, 1, 3), (2, 2, 4)") + // set tidb_server_memory_limit to 1.6GB, tidb_server_memory_limit_sess_min_size to 128MB + tk.MustExec("set global tidb_server_memory_limit='1600MB'") + tk.MustExec("set global tidb_server_memory_limit_sess_min_size=128*1024*1024") + tk.MustExec("set global tidb_mem_oom_action = 'cancel'") + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/issue42662_1", `return(true)`)) + // tk.Session() should be marked as MemoryTop1Tracker but not killed. 
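+	// With issue42662_1 enabled above, the hash join build worker consumes 170MB,
+	// which exceeds tidb_server_memory_limit_sess_min_size (128MB), so this session
+	// should become the MemUsageTop1Tracker candidate.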
+ tk.MustQuery("select /*+ hash_join(t1)*/ * from t1 join t2 on t1.a = t2.a and t1.b = t2.b") + + // try to trigger the kill top1 logic + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/util/servermemorylimit/issue42662_2", `return(true)`)) + time.Sleep(1 * time.Second) + + // no error should be returned + tk.MustQuery("select count(*) from t1") + tk.MustQuery("select count(*) from t1") + + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/issue42662_1")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/util/servermemorylimit/issue42662_2")) +} diff --git a/executor/join.go b/executor/join.go index 8a762ee6ef851..77054a6d4dedc 100644 --- a/executor/join.go +++ b/executor/join.go @@ -97,8 +97,10 @@ type probeWorker struct { rowIters *chunk.Iterator4Slice rowContainerForProbe *hashRowContainer // for every naaj probe worker, pre-allocate the int slice for store the join column index to check. - needCheckBuildRowPos []int - needCheckProbeRowPos []int + needCheckBuildColPos []int + needCheckProbeColPos []int + needCheckBuildTypes []*types.FieldType + needCheckProbeTypes []*types.FieldType probeChkResourceCh chan *probeChkResource joinChkResourceCh chan *chunk.Chunk probeResultCh chan *chunk.Chunk @@ -176,8 +178,10 @@ func (e *HashJoinExec) Close() error { for _, w := range e.probeWorkers { w.buildSideRows = nil w.buildSideRowPtrs = nil - w.needCheckBuildRowPos = nil - w.needCheckProbeRowPos = nil + w.needCheckBuildColPos = nil + w.needCheckProbeColPos = nil + w.needCheckBuildTypes = nil + w.needCheckProbeTypes = nil w.joinChkResourceCh = nil } @@ -310,6 +314,15 @@ func (w *buildWorker) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chun return } }) + failpoint.Inject("issue42662_1", func(val failpoint.Value) { + if val.(bool) { + if w.hashJoinCtx.sessCtx.GetSessionVars().ConnectionID != 0 { + // consume 170MB memory, this sql should be tracked into MemoryTop1Tracker + w.hashJoinCtx.memTracker.Consume(170 * 1024 * 1024) + } + return + } + }) sessVars := w.hashJoinCtx.sessCtx.GetSessionVars() for { if w.hashJoinCtx.finished.Load() { @@ -604,7 +617,7 @@ func (w *probeWorker) joinNAALOSJMatchProbeSideRow2Chunk(probeKey uint64, probeK } } // step2: match the null bucket secondly. - w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos) + w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes) buildSideRows = w.buildSideRows if err != nil { joinResult.err = err @@ -649,7 +662,7 @@ func (w *probeWorker) joinNAALOSJMatchProbeSideRow2Chunk(probeKey uint64, probeK // case1: NOT IN (empty set): ----------------------> result is . // case2: NOT IN (at least a valid inner row) ------------------> result is . 
// Step1: match null bucket (assumption that null bucket is quite smaller than all hash table bucket rows) - w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos) + w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes) buildSideRows := w.buildSideRows if err != nil { joinResult.err = err @@ -679,7 +692,7 @@ func (w *probeWorker) joinNAALOSJMatchProbeSideRow2Chunk(probeKey uint64, probeK } } // Step2: match all hash table bucket build rows (use probeKeyNullBits to filter if any). - w.buildSideRows, err = w.rowContainerForProbe.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos) + w.buildSideRows, err = w.rowContainerForProbe.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes) buildSideRows = w.buildSideRows if err != nil { joinResult.err = err @@ -728,7 +741,7 @@ func (w *probeWorker) joinNAASJMatchProbeSideRow2Chunk(probeKey uint64, probeKey if probeKeyNullBits == nil { // step1: match null bucket first. // need fetch the "valid" rows every time. (nullBits map check is necessary) - w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos) + w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes) buildSideRows := w.buildSideRows if err != nil { joinResult.err = err @@ -803,7 +816,7 @@ func (w *probeWorker) joinNAASJMatchProbeSideRow2Chunk(probeKey uint64, probeKey // case1: NOT IN (empty set): ----------------------> accept rhs row. // case2: NOT IN (at least a valid inner row) ------------------> unknown result, refuse rhs row. // Step1: match null bucket (assumption that null bucket is quite smaller than all hash table bucket rows) - w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos) + w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes) buildSideRows := w.buildSideRows if err != nil { joinResult.err = err @@ -833,7 +846,7 @@ func (w *probeWorker) joinNAASJMatchProbeSideRow2Chunk(probeKey uint64, probeKey } } // Step2: match all hash table bucket build rows. 
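+	// Note: codec.EqualChunkRow treats its type slices and column-index slices
+	// as parallel arrays, which is why the checked field types are gathered
+	// alongside the column positions (needCheckBuildTypes/needCheckProbeTypes)
+	// rather than indexing hCtx.allTypes with row offsets.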
-	w.buildSideRows, err = w.rowContainerForProbe.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos)
+	w.buildSideRows, err = w.rowContainerForProbe.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes)
 	buildSideRows = w.buildSideRows
 	if err != nil {
 		joinResult.err = err
diff --git a/executor/join_test.go b/executor/join_test.go
index 6f56d0a18dc8e..b2dceafb23098 100644
--- a/executor/join_test.go
+++ b/executor/join_test.go
@@ -2909,3 +2909,18 @@ func TestCartesianJoinPanic(t *testing.T) {
 	require.NotNil(t, err)
 	require.True(t, strings.Contains(err.Error(), "Out Of Memory Quota!"))
 }
+
+func TestTiDBNAAJ(t *testing.T) {
+	store := testkit.CreateMockStore(t)
+	tk := testkit.NewTestKit(t, store)
+	tk.MustExec("use test")
+	tk.MustExec("drop table if exists t")
+	tk.MustExec("set @@session.tidb_enable_null_aware_anti_join=0;")
+	tk.MustExec("create table t(a decimal(40,0), b bigint(20) not null);")
+	tk.MustExec("insert into t values(7,8),(7,8),(3,4),(3,4),(9,2),(9,2),(2,0),(2,0),(0,4),(0,4),(8,8),(8,8),(6,1),(6,1),(NULL, 0),(NULL,0);")
+	tk.MustQuery("select ( table1 . a , table1 . b ) NOT IN ( SELECT 3 , 2 UNION SELECT 9, 2 ) AS field2 from t as table1 order by field2;").Check(testkit.Rows(
+		"0", "0", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1"))
+	tk.MustExec("set @@session.tidb_enable_null_aware_anti_join=1;")
+	tk.MustQuery("select ( table1 . a , table1 . b ) NOT IN ( SELECT 3 , 2 UNION SELECT 9, 2 ) AS field2 from t as table1 order by field2;").Check(testkit.Rows(
+		"0", "0", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1"))
+}
diff --git a/server/conn.go b/server/conn.go
index 843c89e175f2d..8d13186872a6b 100644
--- a/server/conn.go
+++ b/server/conn.go
@@ -2121,12 +2121,20 @@ func (cc *clientConn) handleStmt(ctx context.Context, stmt ast.StmtNode, warns [
 	cc.audit(plugin.Starting)
 	rs, err := cc.ctx.ExecuteStmt(ctx, stmt)
 	reg.End()
-	// The session tracker detachment from global tracker is solved in the `rs.Close` in most cases.
-	// If the rs is nil, the detachment will be done in the `handleNoDelay`.
+	// - If rs is not nil, the statement tracker detachment from the session
+	//   tracker is done in `rs.Close` in most cases.
+	// - If rs is nil and err is nil, the detachment is done in
+	//   `handleNoDelay`.
 	if rs != nil {
 		defer terror.Call(rs.Close)
 	}
 	if err != nil {
+		// If an error is returned during the planner phase or the executor.Open
+		// phase, rs will be nil, and StmtCtx.MemTracker and StmtCtx.DiskTracker
+		// will not have been detached. Detach them manually.
+		if sv := cc.ctx.GetSessionVars(); sv != nil && sv.StmtCtx != nil {
+			sv.StmtCtx.DetachMemDiskTracker()
+		}
 		return true, err
 	}
diff --git a/session/bootstrap.go b/session/bootstrap.go
index 5246ea320af0a..5897d24cb8e4c 100644
--- a/session/bootstrap.go
+++ b/session/bootstrap.go
@@ -738,6 +738,8 @@ const (
 	version109 = 109
 	// version110 sets tidb_server_memory_limit to "80%"
 	version110 = 110
+	// version111 sets tidb_stats_load_pseudo_timeout to ON when a cluster upgrades from a version lower than v6.5.0.
+	version111 = 111
 )
 
 // currentBootstrapVersion is defined as a variable, so we can modify its value for testing.
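For context on the bootstrap hunks that follow, a minimal sketch of the upgrade-step convention `upgradeToVer111` uses (illustration only; the session type is stubbed out): each step is guarded by the stored bootstrap version, so re-running it is a no-op, and `REPLACE` deliberately overwrites any user-set value, which is exactly the behavior the test further down asserts (a pre-set "0" becomes "1").

```go
package main

import "fmt"

const version111 = 111

// session stands in for TiDB's session.Session in this sketch.
type session interface{}

func upgradeToVer111(_ session, ver int64) {
	if ver >= version111 {
		return // cluster already bootstrapped at or past this version
	}
	// REPLACE overwrites any existing user-set value, forcing the variable ON.
	fmt.Println("REPLACE HIGH_PRIORITY INTO mysql.GLOBAL_VARIABLES VALUES ('tidb_stats_load_pseudo_timeout', 1)")
}

func main() {
	upgradeToVer111(nil, 110) // coming from an older cluster: the step runs
	upgradeToVer111(nil, 111) // already upgraded: no-op
}
```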
@@ -860,6 +862,7 @@
 	upgradeToVer108,
 	upgradeToVer109,
 	upgradeToVer110,
+	upgradeToVer111,
 }
 )
 
@@ -2219,6 +2222,15 @@ func upgradeToVer110(s Session, ver int64) {
 		mysql.SystemDB, mysql.GlobalVariablesTable, variable.DefTiDBServerMemoryLimit, variable.TiDBServerMemoryLimit, "0")
 }
 
+// For users that upgrade TiDB from a version between v5.4 and v6.4, we want to enable tidb_stats_load_pseudo_timeout by default.
+func upgradeToVer111(s Session, ver int64) {
+	if ver >= version111 {
+		return
+	}
+	mustExecute(s, "REPLACE HIGH_PRIORITY INTO %n.%n VALUES (%?, %?);",
+		mysql.SystemDB, mysql.GlobalVariablesTable, variable.TiDBStatsLoadPseudoTimeout, 1)
+}
+
 func writeOOMAction(s Session) {
 	comment := "oom-action is `log` by default in v3.0.x, `cancel` by default in v4.0.11+"
 	mustExecute(s, `INSERT HIGH_PRIORITY INTO %n.%n VALUES (%?, %?, %?) ON DUPLICATE KEY UPDATE VARIABLE_VALUE= %?`,
diff --git a/session/bootstrap_test.go b/session/bootstrap_test.go
index 99ad8518a52c5..3f08d9461c17e 100644
--- a/session/bootstrap_test.go
+++ b/session/bootstrap_test.go
@@ -1490,3 +1490,56 @@ func TestTiDBServerMemoryLimitUpgradeTo651_2(t *testing.T) {
 	require.Equal(t, 2, row.Len())
 	require.Equal(t, "70%", row.GetString(1))
 }
+
+func TestTiDBStatsLoadPseudoTimeoutUpgradeFrom610To650(t *testing.T) {
+	ctx := context.Background()
+	store, _ := createStoreAndBootstrap(t)
+	defer func() { require.NoError(t, store.Close()) }()
+
+	// upgrade from 6.1 to 6.5+.
+	ver61 := version91 // bootstrap version shipped in v6.1
+	seV61 := createSessionAndSetID(t, store)
+	txn, err := store.Begin()
+	require.NoError(t, err)
+	m := meta.NewMeta(txn)
+	err = m.FinishBootstrap(int64(ver61))
+	require.NoError(t, err)
+	err = txn.Commit(context.Background())
+	require.NoError(t, err)
+	mustExec(t, seV61, fmt.Sprintf("update mysql.tidb set variable_value=%d where variable_name='tidb_server_version'", ver61))
+	mustExec(t, seV61, fmt.Sprintf("update mysql.GLOBAL_VARIABLES set variable_value='%s' where variable_name='%s'", "0", variable.TiDBStatsLoadPseudoTimeout))
+	mustExec(t, seV61, "commit")
+	unsetStoreBootstrapped(store.UUID())
+	ver, err := getBootstrapVersion(seV61)
+	require.NoError(t, err)
+	require.Equal(t, int64(ver61), ver)
+
+	// We are now in 6.1, tidb_stats_load_pseudo_timeout is OFF.
+	res := mustExec(t, seV61, fmt.Sprintf("select * from mysql.GLOBAL_VARIABLES where variable_name='%s'", variable.TiDBStatsLoadPseudoTimeout))
+	chk := res.NewChunk(nil)
+	err = res.Next(ctx, chk)
+	require.NoError(t, err)
+	require.Equal(t, 1, chk.NumRows())
+	row := chk.GetRow(0)
+	require.Equal(t, 2, row.Len())
+	require.Equal(t, "0", row.GetString(1))
+
+	// Upgrade to 6.5.
+	domCurVer, err := BootstrapSession(store)
+	require.NoError(t, err)
+	defer domCurVer.Close()
+	seCurVer := createSessionAndSetID(t, store)
+	ver, err = getBootstrapVersion(seCurVer)
+	require.NoError(t, err)
+	require.Equal(t, currentBootstrapVersion, ver)
+
+	// We are now in 6.5.
+ res = mustExec(t, seCurVer, fmt.Sprintf("select * from mysql.GLOBAL_VARIABLES where variable_name='%s'", variable.TiDBStatsLoadPseudoTimeout)) + chk = res.NewChunk(nil) + err = res.Next(ctx, chk) + require.NoError(t, err) + require.Equal(t, 1, chk.NumRows()) + row = chk.GetRow(0) + require.Equal(t, 2, row.Len()) + require.Equal(t, "1", row.GetString(1)) +} diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index cdd5728d3ac54..1b2fabd4fcee2 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -1114,6 +1114,19 @@ func (sc *StatementContext) UseDynamicPartitionPrune() bool { return sc.UseDynamicPruneMode } +// DetachMemDiskTracker detaches the memory and disk tracker from the sessionTracker. +func (sc *StatementContext) DetachMemDiskTracker() { + if sc == nil { + return + } + if sc.MemTracker != nil { + sc.MemTracker.Detach() + } + if sc.DiskTracker != nil { + sc.DiskTracker.Detach() + } +} + // CopTasksDetails collects some useful information of cop-tasks during execution. type CopTasksDetails struct { NumCopTasks int diff --git a/testkit/mocksessionmanager.go b/testkit/mocksessionmanager.go index 67280bc2e4cbe..5cf67cec251fb 100644 --- a/testkit/mocksessionmanager.go +++ b/testkit/mocksessionmanager.go @@ -33,7 +33,7 @@ type MockSessionManager struct { SerID uint64 TxnInfo []*txninfo.TxnInfo Dom *domain.Domain - conn map[uint64]session.Session + Conn map[uint64]session.Session mu sync.Mutex } @@ -44,8 +44,8 @@ func (msm *MockSessionManager) ShowTxnList() []*txninfo.TxnInfo { if len(msm.TxnInfo) > 0 { return msm.TxnInfo } - rs := make([]*txninfo.TxnInfo, 0, len(msm.conn)) - for _, se := range msm.conn { + rs := make([]*txninfo.TxnInfo, 0, len(msm.Conn)) + for _, se := range msm.Conn { info := se.TxnInfo() if info != nil { rs = append(rs, info) @@ -66,7 +66,7 @@ func (msm *MockSessionManager) ShowProcessList() map[uint64]*util.ProcessInfo { return ret } msm.mu.Lock() - for connID, pi := range msm.conn { + for connID, pi := range msm.Conn { ret[connID] = pi.ShowProcess() } msm.mu.Unlock() @@ -89,7 +89,7 @@ func (msm *MockSessionManager) GetProcessInfo(id uint64) (*util.ProcessInfo, boo } msm.mu.Lock() defer msm.mu.Unlock() - if sess := msm.conn[id]; sess != nil { + if sess := msm.Conn[id]; sess != nil { return sess.ShowProcess(), true } if msm.Dom != nil { @@ -130,7 +130,7 @@ func (*MockSessionManager) GetInternalSessionStartTSList() []uint64 { // KillNonFlashbackClusterConn implement SessionManager interface. 
 func (msm *MockSessionManager) KillNonFlashbackClusterConn() {
-	for _, se := range msm.conn {
+	for _, se := range msm.Conn {
 		processInfo := se.ShowProcess()
 		ddl, ok := processInfo.StmtCtx.GetPlan().(*core.DDL)
 		if !ok {
@@ -148,7 +148,7 @@ func (msm *MockSessionManager) KillNonFlashbackClusterConn() {
 // CheckOldRunningTxn is to get all startTS of every transactions running in the current internal sessions
 func (msm *MockSessionManager) CheckOldRunningTxn(job2ver map[int64]int64, job2ids map[int64]string) {
 	msm.mu.Lock()
-	for _, se := range msm.conn {
+	for _, se := range msm.Conn {
 		session.RemoveLockDDLJobs(se, job2ver, job2ids)
 	}
 	msm.mu.Unlock()
diff --git a/testkit/testkit.go b/testkit/testkit.go
index db86548ee3bfd..e4461fee82ac6 100644
--- a/testkit/testkit.go
+++ b/testkit/testkit.go
@@ -69,10 +69,10 @@ func NewTestKit(t testing.TB, store kv.Storage) *TestKit {
 	mockSm, ok := sm.(*MockSessionManager)
 	if ok {
 		mockSm.mu.Lock()
-		if mockSm.conn == nil {
-			mockSm.conn = make(map[uint64]session.Session)
+		if mockSm.Conn == nil {
+			mockSm.Conn = make(map[uint64]session.Session)
 		}
-		mockSm.conn[tk.session.GetSessionVars().ConnectionID] = tk.session
+		mockSm.Conn[tk.session.GetSessionVars().ConnectionID] = tk.session
 		mockSm.mu.Unlock()
 	}
 	tk.session.SetSessionManager(sm)
diff --git a/util/memory/tracker.go b/util/memory/tracker.go
index b4ffea612ec53..6ee472599e277 100644
--- a/util/memory/tracker.go
+++ b/util/memory/tracker.go
@@ -302,6 +302,9 @@ func (t *Tracker) AttachTo(parent *Tracker) {
 
 // Detach de-attach the tracker child from its parent, then set its parent property as nil
 func (t *Tracker) Detach() {
+	if t == nil {
+		return
+	}
 	parent := t.getParent()
 	if parent == nil {
 		return
@@ -446,6 +449,7 @@ func (t *Tracker) Consume(bs int64) {
 				currentAction = nextAction
 				nextAction = currentAction.GetFallback()
 			}
+			logutil.BgLogger().Warn("global memory controller, running the last fallback action", zap.Any("action", currentAction))
 			currentAction.Action(tracker)
 		}
 	}
@@ -471,6 +475,7 @@ func (t *Tracker) Consume(bs int64) {
 			}
 			oldTracker = MemUsageTop1Tracker.Load()
 		}
+		logutil.BgLogger().Info("global memory controller, update the Top1 session", zap.Int64("memUsage", memUsage), zap.Uint64("conn", sessionRootTracker.SessionID), zap.Uint64("limitSessMinSize", limitSessMinSize))
 	}
 }
diff --git a/util/servermemorylimit/BUILD.bazel b/util/servermemorylimit/BUILD.bazel
index 0d2c4d4f3cb59..881894c0e073b 100644
--- a/util/servermemorylimit/BUILD.bazel
+++ b/util/servermemorylimit/BUILD.bazel
@@ -11,6 +11,7 @@ go_library(
         "//util",
         "//util/logutil",
         "//util/memory",
+        "@com_github_pingcap_failpoint//:failpoint",
        "@org_uber_go_atomic//:atomic",
        "@org_uber_go_zap//:zap",
     ],
diff --git a/util/servermemorylimit/servermemorylimit.go b/util/servermemorylimit/servermemorylimit.go
index 511a86703db17..38a679a4ad755 100644
--- a/util/servermemorylimit/servermemorylimit.go
+++ b/util/servermemorylimit/servermemorylimit.go
@@ -21,6 +21,7 @@ import (
 	"sync/atomic"
 	"time"
 
+	"github.com/pingcap/failpoint"
 	"github.com/pingcap/tidb/parser/mysql"
 	"github.com/pingcap/tidb/types"
 	"github.com/pingcap/tidb/util"
@@ -88,6 +89,15 @@ type sessionToBeKilled struct {
 	lastLogTime time.Time
 }
 
+func (s *sessionToBeKilled) reset() {
+	s.isKilling = false
+	s.sqlStartTime = time.Time{}
+	s.sessionID = 0
+	s.sessionTracker = nil
+	s.killStartTime = time.Time{}
+	s.lastLogTime = time.Time{}
+}
+
 func killSessIfNeeded(s *sessionToBeKilled, bt uint64, sm util.SessionManager) {
 	if s.isKilling {
 		if info, ok := sm.GetProcessInfo(s.sessionID); ok {
@@ -104,7 +114,9 @@ func killSessIfNeeded(s *sessionToBeKilled, bt uint64, sm util.SessionManager) {
 			return
 		}
 	}
-	s.isKilling = false
+	// Clear the Top1 candidate before reset(): reset() sets s.sessionTracker
+	// to nil, so the CompareAndSwap must read it first.
+	memory.MemUsageTop1Tracker.CompareAndSwap(s.sessionTracker, nil)
+	//nolint: all_revive,revive
+	s.reset()
 	IsKilling.Store(false)
-	memory.MemUsageTop1Tracker.CompareAndSwap(s.sessionTracker, nil)
-	//nolint: all_revive,revive
@@ -115,14 +125,25 @@ func killSessIfNeeded(s *sessionToBeKilled, bt uint64, sm util.SessionManager) {
 	if bt == 0 {
 		return
 	}
+	failpoint.Inject("issue42662_2", func(val failpoint.Value) {
+		if val.(bool) {
+			bt = 1
+		}
+	})
 	instanceStats := memory.ReadMemStats()
 	if instanceStats.HeapInuse > MemoryMaxUsed.Load() {
 		MemoryMaxUsed.Store(instanceStats.HeapInuse)
 	}
+	limitSessMinSize := memory.ServerMemoryLimitSessMinSize.Load()
 	if instanceStats.HeapInuse > bt {
 		t := memory.MemUsageTop1Tracker.Load()
 		if t != nil {
-			if info, ok := sm.GetProcessInfo(t.SessionID); ok {
+			memUsage := t.BytesConsumed()
+			// If the Top1 session consumes less than tidb_server_memory_limit_sess_min_size, there is no need to kill it.
+			if uint64(memUsage) < limitSessMinSize {
+				memory.MemUsageTop1Tracker.CompareAndSwap(t, nil)
+				t = nil
+			} else if info, ok := sm.GetProcessInfo(t.SessionID); ok {
 				logutil.BgLogger().Warn("global memory controller tries to kill the top1 memory consumer",
 					zap.Uint64("connID", info.ID),
 					zap.String("sql digest", info.Digest),
@@ -146,6 +167,17 @@ func killSessIfNeeded(s *sessionToBeKilled, bt uint64, sm util.SessionManager) {
 				s.killStartTime = time.Now()
 			}
 		}
+		// If no session consuming more than tidb_server_memory_limit_sess_min_size is found, do not kill any session.
+		if t == nil {
+			if s.lastLogTime.IsZero() {
+				s.lastLogTime = time.Now()
+			}
+			if time.Since(s.lastLogTime) < 5*time.Second {
+				return
+			}
+			logutil.BgLogger().Warn("global memory controller tries to kill the top1 memory consumer, but no session consumes more than tidb_server_memory_limit_sess_min_size", zap.Uint64("tidb_server_memory_limit_sess_min_size", limitSessMinSize))
+			s.lastLogTime = time.Now()
+		}
 	}
 }
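Taken together, the servermemorylimit changes gate the kill-top1 path on `tidb_server_memory_limit_sess_min_size`. A self-contained sketch of that gating (simplified, with a hypothetical `tracker` type standing in for `memory.Tracker`; not TiDB's actual API):

```go
package main

import "fmt"

// tracker is a stand-in for memory.Tracker in this sketch.
type tracker struct {
	sessionID uint64
	bytes     uint64
}

// shouldKill mirrors the gating added to killSessIfNeeded: only act when the
// instance heap exceeds the server limit, and only kill the Top1 candidate
// when it consumes at least tidb_server_memory_limit_sess_min_size; otherwise
// the stale candidate is dropped so a bigger consumer can be elected later.
func shouldKill(top1 *tracker, heapInuse, serverLimit, sessMinSize uint64) (candidate *tracker, kill bool) {
	if heapInuse <= serverLimit || top1 == nil {
		return top1, false
	}
	if top1.bytes < sessMinSize {
		return nil, false // too small to be worth killing; clear the candidate
	}
	return top1, true
}

func main() {
	small := &tracker{sessionID: 1, bytes: 64 << 20}  // 64MB, below sess_min_size
	big := &tracker{sessionID: 2, bytes: 170 << 20}   // 170MB, above sess_min_size
	for _, tr := range []*tracker{small, big} {
		_, kill := shouldKill(tr, 2<<30, 1600<<20, 128<<20)
		fmt.Printf("session %d: kill=%v\n", tr.sessionID, kill)
	}
}
```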