diff --git a/domain/domain.go b/domain/domain.go index 649b3e7b1287d..1742102cc9445 100644 --- a/domain/domain.go +++ b/domain/domain.go @@ -1215,9 +1215,12 @@ func (do *Domain) NotifyUpdatePrivilege(ctx sessionctx.Context) { } } // update locally - _, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(`FLUSH PRIVILEGES`) - if err != nil { - logutil.BgLogger().Error("unable to update privileges", zap.Error(err)) + exec := ctx.(sqlexec.RestrictedSQLExecutor) + if stmt, err := exec.ParseWithParams(context.Background(), `FLUSH PRIVILEGES`); err == nil { + _, _, err := exec.ExecRestrictedStmt(context.Background(), stmt) + if err != nil { + logutil.BgLogger().Error("unable to update privileges", zap.Error(err)) + } } } diff --git a/server/sql_info_fetcher.go b/server/sql_info_fetcher.go index 34236f8eabe7d..76ba5d6682341 100644 --- a/server/sql_info_fetcher.go +++ b/server/sql_info_fetcher.go @@ -88,7 +88,7 @@ func (sh *sqlInfoFetcher) zipInfoForSQL(w http.ResponseWriter, r *http.Request) timeoutString := r.FormValue("timeout") curDB := strings.ToLower(r.FormValue("current_db")) if curDB != "" { - _, err = sh.s.Execute(reqCtx, fmt.Sprintf("use %v", curDB)) + _, err = sh.s.ExecuteInternal(reqCtx, "use %n", curDB) if err != nil { serveError(w, http.StatusInternalServerError, fmt.Sprintf("use database %v failed, err: %v", curDB, err)) return diff --git a/session/session.go b/session/session.go index 8ef1b1a6d740e..0c76f0015c9da 100644 --- a/session/session.go +++ b/session/session.go @@ -104,16 +104,6 @@ type Session interface { ExecuteStmt(context.Context, ast.StmtNode) (sqlexec.RecordSet, error) // Parse is deprecated, use ParseWithParams() instead. Parse(ctx context.Context, sql string) ([]ast.StmtNode, error) - // ParseWithParams is the parameterized version of Parse: it will try to prevent injection under utf8mb4. - // It works like printf() in c, there are following format specifiers: - // 1. %?: automatic conversion by the type of arguments. E.g. []string -> ('s1','s2'..) - // 2. %%: output % - // 3. %n: for identifiers, for example ("use %n", db) - // - // Attention: it does not prevent you from doing parse("select '%?", ";SQL injection!;") => "select '';SQL injection!;'". - // One argument should be a standalone entity. It should not "concat" with other placeholders and characters. - // This function only saves you from processing potentially unsafe parameters. - ParseWithParams(ctx context.Context, sql string, args ...interface{}) (ast.StmtNode, error) // ExecuteInternal is a helper around ParseWithParams() and ExecuteStmt(). It is not allowed to execute multiple statements. ExecuteInternal(context.Context, string, ...interface{}) ([]sqlexec.RecordSet, error) String() string // String is used to debug. @@ -1286,6 +1276,77 @@ func (s *session) ParseWithParams(ctx context.Context, sql string, args ...inter return stmts[0], nil } +// ExecRestrictedStmt implements RestrictedSQLExecutor interface. +func (s *session) ExecRestrictedStmt(ctx context.Context, stmtNode ast.StmtNode, opts ...sqlexec.OptionFuncAlias) ( + []chunk.Row, []*ast.ResultField, error) { + var execOption sqlexec.ExecOption + for _, opt := range opts { + opt(&execOption) + } + // Use special session to execute the sql. + tmp, err := s.sysSessionPool().Get() + if err != nil { + return nil, nil, err + } + defer s.sysSessionPool().Put(tmp) + se := tmp.(*session) + + startTime := time.Now() + // The special session will share the `InspectionTableCache` with current session + // if the current session in inspection mode. + if cache := s.sessionVars.InspectionTableCache; cache != nil { + se.sessionVars.InspectionTableCache = cache + defer func() { se.sessionVars.InspectionTableCache = nil }() + } + defer func() { + if !execOption.IgnoreWarning { + if se != nil && se.GetSessionVars().StmtCtx.WarningCount() > 0 { + warnings := se.GetSessionVars().StmtCtx.GetWarnings() + s.GetSessionVars().StmtCtx.AppendWarnings(warnings) + } + } + }() + + if execOption.SnapshotTS != 0 { + se.sessionVars.SnapshotInfoschema, err = domain.GetDomain(s).GetSnapshotInfoSchema(execOption.SnapshotTS) + if err != nil { + return nil, nil, err + } + if err := se.sessionVars.SetSystemVar(variable.TiDBSnapshot, strconv.FormatUint(execOption.SnapshotTS, 10)); err != nil { + return nil, nil, err + } + defer func() { + if err := se.sessionVars.SetSystemVar(variable.TiDBSnapshot, ""); err != nil { + logutil.BgLogger().Error("set tidbSnapshot error", zap.Error(err)) + } + se.sessionVars.SnapshotInfoschema = nil + }() + } + + metrics.SessionRestrictedSQLCounter.Inc() + + ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) + rs, err := se.ExecuteStmt(ctx, stmtNode) + if err != nil { + se.sessionVars.StmtCtx.AppendError(err) + } + if rs == nil { + return nil, nil, err + } + defer func() { + if closeErr := rs.Close(); closeErr != nil { + err = closeErr + } + }() + var rows []chunk.Row + rows, err = drainRecordSet(ctx, se, rs) + if err != nil { + return nil, nil, err + } + metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal).Observe(time.Since(startTime).Seconds()) + return rows, rs.Fields(), err +} + func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlexec.RecordSet, error) { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("session.ExecuteStmt", opentracing.ChildOf(span.Context())) diff --git a/session/session_test.go b/session/session_test.go index 9a6b78c46beb5..a32e9e05349ed 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -3573,6 +3573,7 @@ func (s *testSessionSuite2) TestRetryCommitWithSet(c *C) { func (s *testSessionSerialSuite) TestParseWithParams(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) se := tk.Se + exec := se.(sqlexec.RestrictedSQLExecutor) // test compatibility with ExcuteInternal origin := se.GetSessionVars().InRestrictedSQL @@ -3580,11 +3581,11 @@ func (s *testSessionSerialSuite) TestParseWithParams(c *C) { defer func() { se.GetSessionVars().InRestrictedSQL = origin }() - _, err := se.ParseWithParams(context.Background(), "SELECT 4") + _, err := exec.ParseWithParams(context.Background(), "SELECT 4") c.Assert(err, IsNil) // test charset attack - stmt, err := se.ParseWithParams(context.Background(), "SELECT * FROM test WHERE name = %? LIMIT 1", "\xbf\x27 OR 1=1 /*") + stmt, err := exec.ParseWithParams(context.Background(), "SELECT * FROM test WHERE name = %? LIMIT 1", "\xbf\x27 OR 1=1 /*") c.Assert(err, IsNil) var sb strings.Builder @@ -3594,15 +3595,15 @@ func (s *testSessionSerialSuite) TestParseWithParams(c *C) { c.Assert(sb.String(), Equals, "SELECT * FROM test WHERE name=_utf8mb4\"\xbf' OR 1=1 /*\" LIMIT 1") // test invalid sql - _, err = se.ParseWithParams(context.Background(), "SELECT") + _, err = exec.ParseWithParams(context.Background(), "SELECT") c.Assert(err, ErrorMatches, ".*You have an error in your SQL syntax.*") // test invalid arguments to escape - _, err = se.ParseWithParams(context.Background(), "SELECT %?") + _, err = exec.ParseWithParams(context.Background(), "SELECT %?") c.Assert(err, ErrorMatches, "missing arguments.*") // test noescape - stmt, err = se.ParseWithParams(context.TODO(), "SELECT 3") + stmt, err = exec.ParseWithParams(context.TODO(), "SELECT 3") c.Assert(err, IsNil) sb.Reset() diff --git a/store/tikv/gcworker/gc_worker.go b/store/tikv/gcworker/gc_worker.go index 0ceea6292019b..b606dec2b1fac 100644 --- a/store/tikv/gcworker/gc_worker.go +++ b/store/tikv/gcworker/gc_worker.go @@ -285,7 +285,7 @@ func (w *GCWorker) prepare() (bool, uint64, error) { ctx := context.Background() se := createSession(w.store) defer se.Close() - _, err := se.Execute(ctx, "BEGIN") + _, err := se.ExecuteInternal(ctx, "BEGIN") if err != nil { return false, 0, errors.Trace(err) } @@ -1599,7 +1599,7 @@ func (w *GCWorker) checkLeader() (bool, error) { defer se.Close() ctx := context.Background() - _, err := se.Execute(ctx, "BEGIN") + _, err := se.ExecuteInternal(ctx, "BEGIN") if err != nil { return false, errors.Trace(err) } @@ -1624,7 +1624,7 @@ func (w *GCWorker) checkLeader() (bool, error) { se.RollbackTxn(ctx) - _, err = se.Execute(ctx, "BEGIN") + _, err = se.ExecuteInternal(ctx, "BEGIN") if err != nil { return false, errors.Trace(err) } @@ -1732,8 +1732,7 @@ func (w *GCWorker) loadValueFromSysTable(key string) (string, error) { ctx := context.Background() se := createSession(w.store) defer se.Close() - stmt := fmt.Sprintf(`SELECT HIGH_PRIORITY (variable_value) FROM mysql.tidb WHERE variable_name='%s' FOR UPDATE`, key) - rs, err := se.Execute(ctx, stmt) + rs, err := se.ExecuteInternal(ctx, `SELECT HIGH_PRIORITY (variable_value) FROM mysql.tidb WHERE variable_name=%? FOR UPDATE`, key) if len(rs) > 0 { defer terror.Call(rs[0].Close) } @@ -1758,13 +1757,14 @@ func (w *GCWorker) loadValueFromSysTable(key string) (string, error) { } func (w *GCWorker) saveValueToSysTable(key, value string) error { - stmt := fmt.Sprintf(`INSERT HIGH_PRIORITY INTO mysql.tidb VALUES ('%[1]s', '%[2]s', '%[3]s') + const stmt = `INSERT HIGH_PRIORITY INTO mysql.tidb VALUES (%?, %?, %?) ON DUPLICATE KEY - UPDATE variable_value = '%[2]s', comment = '%[3]s'`, - key, value, gcVariableComments[key]) + UPDATE variable_value = %?, comment = %?` se := createSession(w.store) defer se.Close() - _, err := se.Execute(context.Background(), stmt) + _, err := se.ExecuteInternal(context.Background(), stmt, + key, value, gcVariableComments[key], + value, gcVariableComments[key]) logutil.BgLogger().Debug("[gc worker] save kv", zap.String("key", key), zap.String("value", value), diff --git a/util/admin/admin.go b/util/admin/admin.go index a5f46de92f9aa..7419a1ab6d0b4 100644 --- a/util/admin/admin.go +++ b/util/admin/admin.go @@ -16,12 +16,12 @@ package admin import ( "context" "encoding/json" - "fmt" "math" "sort" "time" "github.com/pingcap/errors" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/errno" @@ -288,13 +288,13 @@ type RecordData struct { Values []types.Datum } -func getCount(ctx sessionctx.Context, sql string) (int64, error) { - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithSnapshot(sql) +func getCount(exec sqlexec.RestrictedSQLExecutor, stmt ast.StmtNode, snapshot uint64) (int64, error) { + rows, _, err := exec.ExecRestrictedStmt(context.Background(), stmt, sqlexec.ExecOptionWithSnapshot(snapshot)) if err != nil { return 0, errors.Trace(err) } if len(rows) != 1 { - return 0, errors.Errorf("can not get count, sql %s result rows %d", sql, len(rows)) + return 0, errors.Errorf("can not get count, rows count = %d", len(rows)) } return rows[0].GetInt64(0), nil } @@ -313,14 +313,34 @@ const ( // otherwise it returns an error and the corresponding index's offset. func CheckIndicesCount(ctx sessionctx.Context, dbName, tableName string, indices []string) (byte, int, error) { // Add `` for some names like `table name`. - sql := fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`%s` USE INDEX()", dbName, tableName) - tblCnt, err := getCount(ctx, sql) + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.Background(), "SELECT COUNT(*) FROM %n.%n USE INDEX()", dbName, tableName) + if err != nil { + return 0, 0, errors.Trace(err) + } + + var snapshot uint64 + txn, err := ctx.Txn(false) + if err != nil { + return 0, 0, err + } + if txn.Valid() { + snapshot = txn.StartTS() + } + if ctx.GetSessionVars().SnapshotTS != 0 { + snapshot = ctx.GetSessionVars().SnapshotTS + } + + tblCnt, err := getCount(exec, stmt, snapshot) if err != nil { return 0, 0, errors.Trace(err) } for i, idx := range indices { - sql = fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`%s` USE INDEX(`%s`)", dbName, tableName, idx) - idxCnt, err := getCount(ctx, sql) + stmt, err := exec.ParseWithParams(context.Background(), "SELECT COUNT(*) FROM %n.%n USE INDEX(%n)", dbName, tableName, idx) + if err != nil { + return 0, i, errors.Trace(err) + } + idxCnt, err := getCount(exec, stmt, snapshot) if err != nil { return 0, i, errors.Trace(err) } diff --git a/util/gcutil/gcutil.go b/util/gcutil/gcutil.go index f265e4dd0603f..6d3c116cdb0a4 100644 --- a/util/gcutil/gcutil.go +++ b/util/gcutil/gcutil.go @@ -14,7 +14,7 @@ package gcutil import ( - "fmt" + "context" "github.com/pingcap/errors" "github.com/pingcap/parser/model" @@ -25,18 +25,20 @@ import ( ) const ( - selectVariableValueSQL = `SELECT HIGH_PRIORITY variable_value FROM mysql.tidb WHERE variable_name='%s'` - insertVariableValueSQL = `INSERT HIGH_PRIORITY INTO mysql.tidb VALUES ('%[1]s', '%[2]s', '%[3]s') - ON DUPLICATE KEY - UPDATE variable_value = '%[2]s', comment = '%[3]s'` + insertVariableValueSQL = `INSERT HIGH_PRIORITY INTO mysql.tidb VALUES (%?, %?, %?) + ON DUPLICATE KEY UPDATE variable_value = %?, comment = %?` + selectVariableValueSQL = `SELECT HIGH_PRIORITY variable_value FROM mysql.tidb WHERE variable_name=%?` ) // CheckGCEnable is use to check whether GC is enable. func CheckGCEnable(ctx sessionctx.Context) (enable bool, err error) { - sql := fmt.Sprintf(selectVariableValueSQL, "tikv_gc_enable") - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) - if err != nil { - return false, errors.Trace(err) + stmt, err1 := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(context.Background(), selectVariableValueSQL, "tikv_gc_enable") + if err1 != nil { + return false, errors.Trace(err1) + } + rows, _, err2 := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.Background(), stmt) + if err1 != nil { + return false, errors.Trace(err2) } if len(rows) != 1 { return false, errors.New("can not get 'tikv_gc_enable'") @@ -46,15 +48,19 @@ func CheckGCEnable(ctx sessionctx.Context) (enable bool, err error) { // DisableGC will disable GC enable variable. func DisableGC(ctx sessionctx.Context) error { - sql := fmt.Sprintf(insertVariableValueSQL, "tikv_gc_enable", "false", "Current GC enable status") - _, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(context.Background(), insertVariableValueSQL, "tikv_gc_enable", "false", "Current GC enable status", "false", "Current GC enable status") + if err == nil { + _, _, err = ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.Background(), stmt) + } return errors.Trace(err) } // EnableGC will enable GC enable variable. func EnableGC(ctx sessionctx.Context) error { - sql := fmt.Sprintf(insertVariableValueSQL, "tikv_gc_enable", "true", "Current GC enable status") - _, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(context.Background(), insertVariableValueSQL, "tikv_gc_enable", "true", "Current GC enable status", "true", "Current GC enable status") + if err == nil { + _, _, err = ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.Background(), stmt) + } return errors.Trace(err) } @@ -80,8 +86,12 @@ func ValidateSnapshotWithGCSafePoint(snapshotTS, safePointTS uint64) error { // GetGCSafePoint loads GC safe point time from mysql.tidb. func GetGCSafePoint(ctx sessionctx.Context) (uint64, error) { - sql := fmt.Sprintf(selectVariableValueSQL, "tikv_gc_safe_point") - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.Background(), selectVariableValueSQL, "tikv_gc_safe_point") + if err != nil { + return 0, errors.Trace(err) + } + rows, _, err := exec.ExecRestrictedStmt(context.Background(), stmt) if err != nil { return 0, errors.Trace(err) } diff --git a/util/sqlexec/restricted_sql_executor.go b/util/sqlexec/restricted_sql_executor.go index 6ea589ec5ff5d..92d6958a38667 100644 --- a/util/sqlexec/restricted_sql_executor.go +++ b/util/sqlexec/restricted_sql_executor.go @@ -42,6 +42,43 @@ type RestrictedSQLExecutor interface { // If current session sets the snapshot timestamp, then execute with this snapshot timestamp. // Otherwise, execute with the current transaction start timestamp if the transaction is valid. ExecRestrictedSQLWithSnapshot(sql string) ([]chunk.Row, []*ast.ResultField, error) + + // The above methods are all deprecated. + // After the refactor finish, they will be removed. + + // ParseWithParams is the parameterized version of Parse: it will try to prevent injection under utf8mb4. + // It works like printf() in c, there are following format specifiers: + // 1. %?: automatic conversion by the type of arguments. E.g. []string -> ('s1','s2'..) + // 2. %%: output % + // 3. %n: for identifiers, for example ("use %n", db) + // + // Attention: it does not prevent you from doing parse("select '%?", ";SQL injection!;") => "select '';SQL injection!;'". + // One argument should be a standalone entity. It should not "concat" with other placeholders and characters. + // This function only saves you from processing potentially unsafe parameters. + ParseWithParams(ctx context.Context, sql string, args ...interface{}) (ast.StmtNode, error) + // ExecRestrictedStmt run sql statement in ctx with some restriction. + ExecRestrictedStmt(ctx context.Context, stmt ast.StmtNode, opts ...OptionFuncAlias) ([]chunk.Row, []*ast.ResultField, error) +} + +// ExecOption is a struct defined for ExecRestrictedSQLWithContext option. +type ExecOption struct { + IgnoreWarning bool + SnapshotTS uint64 +} + +// OptionFuncAlias is defined for the optional paramater of ExecRestrictedSQLWithContext. +type OptionFuncAlias = func(option *ExecOption) + +// ExecOptionIgnoreWarning tells ExecRestrictedSQLWithContext to ignore the warnings. +var ExecOptionIgnoreWarning OptionFuncAlias = func(option *ExecOption) { + option.IgnoreWarning = true +} + +// ExecOptionWithSnapshot tells ExecRestrictedSQLWithContext to use a snapshot. +func ExecOptionWithSnapshot(snapshot uint64) OptionFuncAlias { + return func(option *ExecOption) { + option.SnapshotTS = snapshot + } } // SQLExecutor is an interface provides executing normal sql statement.