Skip to content

Commit

Permalink
session: fix select for update statement can't get stmt-count-limit e…
Browse files Browse the repository at this point in the history
…rror (#48412) (#48469)

close #48411
  • Loading branch information
ti-chi-bot authored Nov 10, 2023
1 parent 76fab0e commit 4bddd59
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 5 deletions.
1 change: 1 addition & 0 deletions pkg/server/internal/testserverclient/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ go_library(
importpath = "github.com/pingcap/tidb/pkg/server/internal/testserverclient",
visibility = ["//pkg/server:__subpackages__"],
deps = [
"//pkg/config",
"//pkg/errno",
"//pkg/kv",
"//pkg/parser/mysql",
Expand Down
74 changes: 74 additions & 0 deletions pkg/server/internal/testserverclient/server_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/log"
"github.com/pingcap/tidb/pkg/config"
"github.com/pingcap/tidb/pkg/errno"
"github.com/pingcap/tidb/pkg/kv"
tmysql "github.com/pingcap/tidb/pkg/parser/mysql"
Expand Down Expand Up @@ -2446,4 +2447,77 @@ func (cli *TestServerClient) RunTestInfoschemaClientErrors(t *testing.T) {
})
}

func (cli *TestServerClient) RunTestStmtCountLimit(t *testing.T) {
originalStmtCountLimit := config.GetGlobalConfig().Performance.StmtCountLimit
config.UpdateGlobal(func(conf *config.Config) {
conf.Performance.StmtCountLimit = 3
})
defer func() {
config.UpdateGlobal(func(conf *config.Config) {
conf.Performance.StmtCountLimit = originalStmtCountLimit
})
}()

cli.RunTests(t, nil, func(dbt *testkit.DBTestKit) {
dbt.MustExec("create table t (id int key);")
dbt.MustExec("set @@tidb_disable_txn_auto_retry=0;")
dbt.MustExec("set autocommit=0;")
dbt.MustExec("begin optimistic;")
dbt.MustExec("insert into t values (1);")
dbt.MustExec("insert into t values (2);")
_, err := dbt.GetDB().Query("select * from t for update;")
require.Error(t, err)
require.Equal(t, "Error 1105 (HY000): statement count 4 exceeds the transaction limitation, transaction has been rollback, autocommit = false", err.Error())
dbt.MustExec("insert into t values (3);")
dbt.MustExec("commit;")
rows := dbt.MustQuery("select * from t;")
var id int
count := 0
for rows.Next() {
rows.Scan(&id)
count++
}
require.NoError(t, rows.Close())
require.Equal(t, 3, id)
require.Equal(t, 1, count)

dbt.MustExec("delete from t;")
dbt.MustExec("commit;")
dbt.MustExec("set @@tidb_disable_txn_auto_retry=0;")
dbt.MustExec("set autocommit=0;")
dbt.MustExec("begin optimistic;")
dbt.MustExec("insert into t values (1);")
dbt.MustExec("insert into t values (2);")
_, err = dbt.GetDB().Exec("insert into t values (3);")
require.Error(t, err)
require.Equal(t, "Error 1105 (HY000): statement count 4 exceeds the transaction limitation, transaction has been rollback, autocommit = false", err.Error())
dbt.MustExec("commit;")
rows = dbt.MustQuery("select count(*) from t;")
for rows.Next() {
rows.Scan(&count)
}
require.NoError(t, rows.Close())
require.Equal(t, 0, count)

dbt.MustExec("delete from t;")
dbt.MustExec("commit;")
dbt.MustExec("set @@tidb_batch_commit=1;")
dbt.MustExec("set @@tidb_disable_txn_auto_retry=0;")
dbt.MustExec("set autocommit=0;")
dbt.MustExec("begin optimistic;")
dbt.MustExec("insert into t values (1);")
dbt.MustExec("insert into t values (2);")
dbt.MustExec("insert into t values (3);")
dbt.MustExec("insert into t values (4);")
dbt.MustExec("insert into t values (5);")
dbt.MustExec("commit;")
rows = dbt.MustQuery("select count(*) from t;")
for rows.Next() {
rows.Scan(&count)
}
require.NoError(t, rows.Close())
require.Equal(t, 5, count)
})
}

//revive:enable:exported
5 changes: 5 additions & 0 deletions pkg/server/tests/tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,11 @@ func TestSumAvg(t *testing.T) {
ts.RunTestSumAvg(t)
}

func TestStmtCountLimit(t *testing.T) {
ts := createTidbTestSuite(t)
ts.RunTestStmtCountLimit(t)
}

func TestNullFlag(t *testing.T) {
ts := createTidbTestSuite(t)

Expand Down
8 changes: 8 additions & 0 deletions pkg/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -2417,6 +2417,14 @@ func runStmt(ctx context.Context, se *session, s sqlexec.Statement) (rs sqlexec.
if err != nil {
return nil, err
}
if sessVars.TxnCtx.CouldRetry && !s.IsReadOnly(sessVars) {
// Only when the txn is could retry and the statement is not read only, need to do stmt-count-limit check,
// otherwise, the stmt won't be add into stmt history, and also don't need check.
// About `stmt-count-limit`, see more in https://docs.pingcap.com/tidb/stable/tidb-configuration-file#stmt-count-limit
if err := checkStmtLimit(ctx, se, false); err != nil {
return nil, err
}
}

rs, err = s.Exec(ctx)
se.updateTelemetryMetric(s.(*executor.ExecStmt))
Expand Down
10 changes: 10 additions & 0 deletions pkg/session/test/txn/txn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,16 @@ func TestBatchCommit(t *testing.T) {
tk.MustExec("insert into t values (7)")
tk1.MustQuery("select * from t").Check(testkit.Rows("5", "6", "7"))

tk.MustExec("delete from t")
tk.MustExec("commit")
tk.MustExec("begin")
tk.MustExec("explain analyze insert into t values (5)")
tk1.MustQuery("select * from t").Check(testkit.Rows())
tk.MustExec("explain analyze insert into t values (6)")
tk1.MustQuery("select * from t").Check(testkit.Rows())
tk.MustExec("explain analyze insert into t values (7)")
tk1.MustQuery("select * from t").Check(testkit.Rows("5", "6", "7"))

// The session is still in transaction.
tk.MustExec("insert into t values (8)")
tk1.MustQuery("select * from t").Check(testkit.Rows("5", "6", "7"))
Expand Down
22 changes: 17 additions & 5 deletions pkg/session/tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ func finishStmt(ctx context.Context, se *session, meetsErr error, sql sqlexec.St
if err != nil {
return err
}
return checkStmtLimit(ctx, se)
return checkStmtLimit(ctx, se, true)
}

func autoCommitAfterStmt(ctx context.Context, se *session, meetsErr error, sql sqlexec.Statement) error {
Expand Down Expand Up @@ -305,18 +305,29 @@ func autoCommitAfterStmt(ctx context.Context, se *session, meetsErr error, sql s
return nil
}

func checkStmtLimit(ctx context.Context, se *session) error {
func checkStmtLimit(ctx context.Context, se *session, isFinish bool) error {
// If the user insert, insert, insert ... but never commit, TiDB would OOM.
// So we limit the statement count in a transaction here.
var err error
sessVars := se.GetSessionVars()
history := GetHistory(se)
if history.Count() > int(config.GetGlobalConfig().Performance.StmtCountLimit) {
stmtCount := history.Count()
if !isFinish {
// history stmt count + current stmt, since current stmt is not finish, it has not add to history.
stmtCount++
}
if stmtCount > int(config.GetGlobalConfig().Performance.StmtCountLimit) {
if !sessVars.BatchCommit {
se.RollbackTxn(ctx)
return errors.Errorf("statement count %d exceeds the transaction limitation, autocommit = %t",
history.Count(), sessVars.IsAutocommit())
return errors.Errorf("statement count %d exceeds the transaction limitation, transaction has been rollback, autocommit = %t",
stmtCount, sessVars.IsAutocommit())
}
if !isFinish {
// if the stmt is not finish execute, then just return, since some work need to be done such as StmtCommit.
return nil
}
// If the stmt is finish execute, and exceed the StmtCountLimit, and BatchCommit is true,
// then commit the current transaction and create a new transaction.
err = sessiontxn.NewTxn(ctx, se)
// The transaction does not committed yet, we need to keep it in transaction.
// The last history could not be "commit"/"rollback" statement.
Expand All @@ -328,6 +339,7 @@ func checkStmtLimit(ctx context.Context, se *session) error {
}

// GetHistory get all stmtHistory in current txn. Exported only for test.
// If stmtHistory is nil, will create a new one for current txn.
func GetHistory(ctx sessionctx.Context) *StmtHistory {
hist, ok := ctx.GetSessionVars().TxnCtx.History.(*StmtHistory)
if ok {
Expand Down

0 comments on commit 4bddd59

Please sign in to comment.