diff --git a/ddl/ddl_db_change_test.go b/ddl/ddl_db_change_test.go index 3f76af2f0df92..fc0badc25c532 100644 --- a/ddl/ddl_db_change_test.go +++ b/ddl/ddl_db_change_test.go @@ -258,8 +258,9 @@ func (t *testExecInfo) compileSQL(idx int) (err error) { ctx := context.TODO() se.PrepareTxnCtx(ctx) sctx := se.(sessionctx.Context) - executor.ResetStmtCtx(sctx, c.rawStmt) - + if err = executor.ResetStmtCtx(sctx, c.rawStmt); err != nil { + return errors.Trace(err) + } c.stmt, err = compiler.Compile(ctx, c.rawStmt) if err != nil { return errors.Trace(err) diff --git a/executor/prepared.go b/executor/prepared.go index 3e3b094ce3d35..4a25584cc5fa3 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -17,6 +17,7 @@ import ( "math" "sort" + "fmt" "github.com/juju/errors" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/config" @@ -206,7 +207,9 @@ func (e *ExecuteExec) Build() error { return errors.Trace(b.err) } e.stmtExec = stmtExec - ResetStmtCtx(e.ctx, e.stmt) + if err = ResetStmtCtx(e.ctx, e.stmt); err != nil { + return err + } CountStmtNode(e.stmt, e.ctx.GetSessionVars().InRestrictedSQL) logExpensiveQuery(e.stmt, e.plan) return nil @@ -258,7 +261,7 @@ func CompileExecutePreparedStmt(ctx sessionctx.Context, ID uint32, args ...inter // ResetStmtCtx resets the StmtContext. // Before every execution, we must clear statement context. -func ResetStmtCtx(ctx sessionctx.Context, s ast.StmtNode) { +func ResetStmtCtx(ctx sessionctx.Context, s ast.StmtNode) (err error) { sessVars := ctx.GetSessionVars() sc := new(stmtctx.StatementContext) sc.TimeZone = sessVars.GetTimeZone() @@ -340,6 +343,15 @@ func ResetStmtCtx(ctx sessionctx.Context, s ast.StmtNode) { sessVars.LastInsertID = 0 } sessVars.ResetPrevAffectedRows() + err = sessVars.SetSystemVar("warning_count", fmt.Sprintf("%d", sessVars.StmtCtx.NumWarnings(false))) + if err != nil { + return errors.Trace(err) + } + err = sessVars.SetSystemVar("error_count", fmt.Sprintf("%d", sessVars.StmtCtx.NumWarnings(true))) + if err != nil { + return errors.Trace(err) + } sessVars.InsertID = 0 sessVars.StmtCtx = sc + return } diff --git a/executor/show_test.go b/executor/show_test.go index e5459811fcbec..4f19df04ae953 100644 --- a/executor/show_test.go +++ b/executor/show_test.go @@ -475,14 +475,17 @@ func (s *testSuite) TestShowWarnings(c *C) { tk.Exec(testSQL) c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(1)) tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Error|1050|Table 'test.show_warnings' already exists")) + tk.MustQuery("select @@error_count").Check(testutil.RowsWithSep("|", "1")) - // Test Warning level 'Level' + // Test Warning level 'Note' testSQL = `create table show_warnings_2 (a int)` tk.MustExec(testSQL) testSQL = `create table if not exists show_warnings_2 like show_warnings` tk.Exec(testSQL) c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(1)) tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Note|1050|Table 'test.show_warnings_2' already exists")) + tk.MustQuery("select @@warning_count").Check(testutil.RowsWithSep("|", "1")) + tk.MustQuery("select @@warning_count").Check(testutil.RowsWithSep("|", "0")) } func (s *testSuite) TestShowErrors(c *C) { diff --git a/session/session.go b/session/session.go index d4cdf73c22eb9..e4046cdd508a1 100644 --- a/session/session.go +++ b/session/session.go @@ -833,7 +833,9 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []ast.Rec } s.PrepareTxnCtx(ctx) - executor.ResetStmtCtx(s, stmtNode) + if err = executor.ResetStmtCtx(s, stmtNode); err != nil { + return nil, errors.Trace(err) + } if recordSets, err = s.executeStatement(ctx, connID, stmtNode, stmt, recordSets); err != nil { return nil, errors.Trace(err) } @@ -862,7 +864,9 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []ast.Rec // Step2: Transform abstract syntax tree to a physical plan(stored in executor.ExecStmt). startTS = time.Now() // Some executions are done in compile stage, so we reset them before compile. - executor.ResetStmtCtx(s, stmtNode) + if err := executor.ResetStmtCtx(s, stmtNode); err != nil { + return nil, errors.Trace(err) + } stmt, err := compiler.Compile(ctx, stmtNode) if err != nil { s.rollbackOnError(ctx) diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index 63ec56e96bbd6..09c3000ceb80a 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -125,6 +125,27 @@ func (sc *StatementContext) WarningCount() uint16 { return wc } +// NumWarnings gets warning count. It's different from `WarningCount` in that +// `WarningCount` return the warning count of the last executed command, so if +// the last command is a SHOW statement, `WarningCount` return 0. On the other +// hand, `NumWarnings` always return number of warnings(or errors if `errOnly` +// is set). +func (sc *StatementContext) NumWarnings(errOnly bool) uint16 { + var wc uint16 + sc.mu.Lock() + defer sc.mu.Unlock() + if errOnly { + for _, warn := range sc.mu.warnings { + if warn.Level == WarnLevelError { + wc++ + } + } + } else { + wc = uint16(len(sc.mu.warnings)) + } + return wc +} + // SetWarnings sets warnings. func (sc *StatementContext) SetWarnings(warns []SQLWarn) { sc.mu.Lock() diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 32d78666ba398..1af501dde1bf3 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -606,6 +606,8 @@ var defaultSysVars = []*SysVar{ {ScopeGlobal | ScopeSession, "min_examined_row_limit", "0"}, {ScopeGlobal, "sync_frm", "ON"}, {ScopeGlobal, "innodb_online_alter_log_max_size", "134217728"}, + {ScopeSession, "warning_count", "0"}, + {ScopeSession, "error_count", "0"}, /* TiDB specific variables */ {ScopeSession, TiDBSnapshot, ""}, {ScopeSession, TiDBImportingData, "0"},