Skip to content

Commit

Permalink
Merge branch 'release-4.0' into release-4.0-405a5d009dcd
Browse files Browse the repository at this point in the history
  • Loading branch information
crazycs520 authored Mar 4, 2021
2 parents ebc8150 + ae010ce commit 7daa3ec
Show file tree
Hide file tree
Showing 8 changed files with 183 additions and 51 deletions.
9 changes: 6 additions & 3 deletions domain/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -1215,9 +1215,12 @@ func (do *Domain) NotifyUpdatePrivilege(ctx sessionctx.Context) {
}
}
// update locally
_, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(`FLUSH PRIVILEGES`)
if err != nil {
logutil.BgLogger().Error("unable to update privileges", zap.Error(err))
exec := ctx.(sqlexec.RestrictedSQLExecutor)
if stmt, err := exec.ParseWithParams(context.Background(), `FLUSH PRIVILEGES`); err == nil {
_, _, err := exec.ExecRestrictedStmt(context.Background(), stmt)
if err != nil {
logutil.BgLogger().Error("unable to update privileges", zap.Error(err))
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion server/sql_info_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (sh *sqlInfoFetcher) zipInfoForSQL(w http.ResponseWriter, r *http.Request)
timeoutString := r.FormValue("timeout")
curDB := strings.ToLower(r.FormValue("current_db"))
if curDB != "" {
_, err = sh.s.Execute(reqCtx, fmt.Sprintf("use %v", curDB))
_, err = sh.s.ExecuteInternal(reqCtx, "use %n", curDB)
if err != nil {
serveError(w, http.StatusInternalServerError, fmt.Sprintf("use database %v failed, err: %v", curDB, err))
return
Expand Down
81 changes: 71 additions & 10 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,6 @@ type Session interface {
ExecuteStmt(context.Context, ast.StmtNode) (sqlexec.RecordSet, error)
// Parse is deprecated, use ParseWithParams() instead.
Parse(ctx context.Context, sql string) ([]ast.StmtNode, error)
// ParseWithParams is the parameterized version of Parse: it will try to prevent injection under utf8mb4.
// It works like printf() in c, there are following format specifiers:
// 1. %?: automatic conversion by the type of arguments. E.g. []string -> ('s1','s2'..)
// 2. %%: output %
// 3. %n: for identifiers, for example ("use %n", db)
//
// Attention: it does not prevent you from doing parse("select '%?", ";SQL injection!;") => "select '';SQL injection!;'".
// One argument should be a standalone entity. It should not "concat" with other placeholders and characters.
// This function only saves you from processing potentially unsafe parameters.
ParseWithParams(ctx context.Context, sql string, args ...interface{}) (ast.StmtNode, error)
// ExecuteInternal is a helper around ParseWithParams() and ExecuteStmt(). It is not allowed to execute multiple statements.
ExecuteInternal(context.Context, string, ...interface{}) ([]sqlexec.RecordSet, error)
String() string // String is used to debug.
Expand Down Expand Up @@ -1286,6 +1276,77 @@ func (s *session) ParseWithParams(ctx context.Context, sql string, args ...inter
return stmts[0], nil
}

// ExecRestrictedStmt implements RestrictedSQLExecutor interface.
func (s *session) ExecRestrictedStmt(ctx context.Context, stmtNode ast.StmtNode, opts ...sqlexec.OptionFuncAlias) (
[]chunk.Row, []*ast.ResultField, error) {
var execOption sqlexec.ExecOption
for _, opt := range opts {
opt(&execOption)
}
// Use special session to execute the sql.
tmp, err := s.sysSessionPool().Get()
if err != nil {
return nil, nil, err
}
defer s.sysSessionPool().Put(tmp)
se := tmp.(*session)

startTime := time.Now()
// The special session will share the `InspectionTableCache` with current session
// if the current session in inspection mode.
if cache := s.sessionVars.InspectionTableCache; cache != nil {
se.sessionVars.InspectionTableCache = cache
defer func() { se.sessionVars.InspectionTableCache = nil }()
}
defer func() {
if !execOption.IgnoreWarning {
if se != nil && se.GetSessionVars().StmtCtx.WarningCount() > 0 {
warnings := se.GetSessionVars().StmtCtx.GetWarnings()
s.GetSessionVars().StmtCtx.AppendWarnings(warnings)
}
}
}()

if execOption.SnapshotTS != 0 {
se.sessionVars.SnapshotInfoschema, err = domain.GetDomain(s).GetSnapshotInfoSchema(execOption.SnapshotTS)
if err != nil {
return nil, nil, err
}
if err := se.sessionVars.SetSystemVar(variable.TiDBSnapshot, strconv.FormatUint(execOption.SnapshotTS, 10)); err != nil {
return nil, nil, err
}
defer func() {
if err := se.sessionVars.SetSystemVar(variable.TiDBSnapshot, ""); err != nil {
logutil.BgLogger().Error("set tidbSnapshot error", zap.Error(err))
}
se.sessionVars.SnapshotInfoschema = nil
}()
}

metrics.SessionRestrictedSQLCounter.Inc()

ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{})
rs, err := se.ExecuteStmt(ctx, stmtNode)
if err != nil {
se.sessionVars.StmtCtx.AppendError(err)
}
if rs == nil {
return nil, nil, err
}
defer func() {
if closeErr := rs.Close(); closeErr != nil {
err = closeErr
}
}()
var rows []chunk.Row
rows, err = drainRecordSet(ctx, se, rs)
if err != nil {
return nil, nil, err
}
metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal).Observe(time.Since(startTime).Seconds())
return rows, rs.Fields(), err
}

func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlexec.RecordSet, error) {
if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil {
span1 := span.Tracer().StartSpan("session.ExecuteStmt", opentracing.ChildOf(span.Context()))
Expand Down
11 changes: 6 additions & 5 deletions session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3573,18 +3573,19 @@ func (s *testSessionSuite2) TestRetryCommitWithSet(c *C) {
func (s *testSessionSerialSuite) TestParseWithParams(c *C) {
tk := testkit.NewTestKitWithInit(c, s.store)
se := tk.Se
exec := se.(sqlexec.RestrictedSQLExecutor)

// test compatibility with ExcuteInternal
origin := se.GetSessionVars().InRestrictedSQL
se.GetSessionVars().InRestrictedSQL = true
defer func() {
se.GetSessionVars().InRestrictedSQL = origin
}()
_, err := se.ParseWithParams(context.Background(), "SELECT 4")
_, err := exec.ParseWithParams(context.Background(), "SELECT 4")
c.Assert(err, IsNil)

// test charset attack
stmt, err := se.ParseWithParams(context.Background(), "SELECT * FROM test WHERE name = %? LIMIT 1", "\xbf\x27 OR 1=1 /*")
stmt, err := exec.ParseWithParams(context.Background(), "SELECT * FROM test WHERE name = %? LIMIT 1", "\xbf\x27 OR 1=1 /*")
c.Assert(err, IsNil)

var sb strings.Builder
Expand All @@ -3594,15 +3595,15 @@ func (s *testSessionSerialSuite) TestParseWithParams(c *C) {
c.Assert(sb.String(), Equals, "SELECT * FROM test WHERE name=_utf8mb4\"\xbf' OR 1=1 /*\" LIMIT 1")

// test invalid sql
_, err = se.ParseWithParams(context.Background(), "SELECT")
_, err = exec.ParseWithParams(context.Background(), "SELECT")
c.Assert(err, ErrorMatches, ".*You have an error in your SQL syntax.*")

// test invalid arguments to escape
_, err = se.ParseWithParams(context.Background(), "SELECT %?")
_, err = exec.ParseWithParams(context.Background(), "SELECT %?")
c.Assert(err, ErrorMatches, "missing arguments.*")

// test noescape
stmt, err = se.ParseWithParams(context.TODO(), "SELECT 3")
stmt, err = exec.ParseWithParams(context.TODO(), "SELECT 3")
c.Assert(err, IsNil)

sb.Reset()
Expand Down
18 changes: 9 additions & 9 deletions store/tikv/gcworker/gc_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ func (w *GCWorker) prepare() (bool, uint64, error) {
ctx := context.Background()
se := createSession(w.store)
defer se.Close()
_, err := se.Execute(ctx, "BEGIN")
_, err := se.ExecuteInternal(ctx, "BEGIN")
if err != nil {
return false, 0, errors.Trace(err)
}
Expand Down Expand Up @@ -1599,7 +1599,7 @@ func (w *GCWorker) checkLeader() (bool, error) {
defer se.Close()

ctx := context.Background()
_, err := se.Execute(ctx, "BEGIN")
_, err := se.ExecuteInternal(ctx, "BEGIN")
if err != nil {
return false, errors.Trace(err)
}
Expand All @@ -1624,7 +1624,7 @@ func (w *GCWorker) checkLeader() (bool, error) {

se.RollbackTxn(ctx)

_, err = se.Execute(ctx, "BEGIN")
_, err = se.ExecuteInternal(ctx, "BEGIN")
if err != nil {
return false, errors.Trace(err)
}
Expand Down Expand Up @@ -1732,8 +1732,7 @@ func (w *GCWorker) loadValueFromSysTable(key string) (string, error) {
ctx := context.Background()
se := createSession(w.store)
defer se.Close()
stmt := fmt.Sprintf(`SELECT HIGH_PRIORITY (variable_value) FROM mysql.tidb WHERE variable_name='%s' FOR UPDATE`, key)
rs, err := se.Execute(ctx, stmt)
rs, err := se.ExecuteInternal(ctx, `SELECT HIGH_PRIORITY (variable_value) FROM mysql.tidb WHERE variable_name=%? FOR UPDATE`, key)
if len(rs) > 0 {
defer terror.Call(rs[0].Close)
}
Expand All @@ -1758,13 +1757,14 @@ func (w *GCWorker) loadValueFromSysTable(key string) (string, error) {
}

func (w *GCWorker) saveValueToSysTable(key, value string) error {
stmt := fmt.Sprintf(`INSERT HIGH_PRIORITY INTO mysql.tidb VALUES ('%[1]s', '%[2]s', '%[3]s')
const stmt = `INSERT HIGH_PRIORITY INTO mysql.tidb VALUES (%?, %?, %?)
ON DUPLICATE KEY
UPDATE variable_value = '%[2]s', comment = '%[3]s'`,
key, value, gcVariableComments[key])
UPDATE variable_value = %?, comment = %?`
se := createSession(w.store)
defer se.Close()
_, err := se.Execute(context.Background(), stmt)
_, err := se.ExecuteInternal(context.Background(), stmt,
key, value, gcVariableComments[key],
value, gcVariableComments[key])
logutil.BgLogger().Debug("[gc worker] save kv",
zap.String("key", key),
zap.String("value", value),
Expand Down
36 changes: 28 additions & 8 deletions util/admin/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ package admin
import (
"context"
"encoding/json"
"fmt"
"math"
"sort"
"time"

"github.com/pingcap/errors"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/model"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/errno"
Expand Down Expand Up @@ -288,13 +288,13 @@ type RecordData struct {
Values []types.Datum
}

func getCount(ctx sessionctx.Context, sql string) (int64, error) {
rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithSnapshot(sql)
func getCount(exec sqlexec.RestrictedSQLExecutor, stmt ast.StmtNode, snapshot uint64) (int64, error) {
rows, _, err := exec.ExecRestrictedStmt(context.Background(), stmt, sqlexec.ExecOptionWithSnapshot(snapshot))
if err != nil {
return 0, errors.Trace(err)
}
if len(rows) != 1 {
return 0, errors.Errorf("can not get count, sql %s result rows %d", sql, len(rows))
return 0, errors.Errorf("can not get count, rows count = %d", len(rows))
}
return rows[0].GetInt64(0), nil
}
Expand All @@ -313,14 +313,34 @@ const (
// otherwise it returns an error and the corresponding index's offset.
func CheckIndicesCount(ctx sessionctx.Context, dbName, tableName string, indices []string) (byte, int, error) {
// Add `` for some names like `table name`.
sql := fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`%s` USE INDEX()", dbName, tableName)
tblCnt, err := getCount(ctx, sql)
exec := ctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.Background(), "SELECT COUNT(*) FROM %n.%n USE INDEX()", dbName, tableName)
if err != nil {
return 0, 0, errors.Trace(err)
}

var snapshot uint64
txn, err := ctx.Txn(false)
if err != nil {
return 0, 0, err
}
if txn.Valid() {
snapshot = txn.StartTS()
}
if ctx.GetSessionVars().SnapshotTS != 0 {
snapshot = ctx.GetSessionVars().SnapshotTS
}

tblCnt, err := getCount(exec, stmt, snapshot)
if err != nil {
return 0, 0, errors.Trace(err)
}
for i, idx := range indices {
sql = fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`%s` USE INDEX(`%s`)", dbName, tableName, idx)
idxCnt, err := getCount(ctx, sql)
stmt, err := exec.ParseWithParams(context.Background(), "SELECT COUNT(*) FROM %n.%n USE INDEX(%n)", dbName, tableName, idx)
if err != nil {
return 0, i, errors.Trace(err)
}
idxCnt, err := getCount(exec, stmt, snapshot)
if err != nil {
return 0, i, errors.Trace(err)
}
Expand Down
40 changes: 25 additions & 15 deletions util/gcutil/gcutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
package gcutil

import (
"fmt"
"context"

"github.com/pingcap/errors"
"github.com/pingcap/parser/model"
Expand All @@ -25,18 +25,20 @@ import (
)

const (
selectVariableValueSQL = `SELECT HIGH_PRIORITY variable_value FROM mysql.tidb WHERE variable_name='%s'`
insertVariableValueSQL = `INSERT HIGH_PRIORITY INTO mysql.tidb VALUES ('%[1]s', '%[2]s', '%[3]s')
ON DUPLICATE KEY
UPDATE variable_value = '%[2]s', comment = '%[3]s'`
insertVariableValueSQL = `INSERT HIGH_PRIORITY INTO mysql.tidb VALUES (%?, %?, %?)
ON DUPLICATE KEY UPDATE variable_value = %?, comment = %?`
selectVariableValueSQL = `SELECT HIGH_PRIORITY variable_value FROM mysql.tidb WHERE variable_name=%?`
)

// CheckGCEnable is use to check whether GC is enable.
func CheckGCEnable(ctx sessionctx.Context) (enable bool, err error) {
sql := fmt.Sprintf(selectVariableValueSQL, "tikv_gc_enable")
rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql)
if err != nil {
return false, errors.Trace(err)
stmt, err1 := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(context.Background(), selectVariableValueSQL, "tikv_gc_enable")
if err1 != nil {
return false, errors.Trace(err1)
}
rows, _, err2 := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.Background(), stmt)
if err1 != nil {
return false, errors.Trace(err2)
}
if len(rows) != 1 {
return false, errors.New("can not get 'tikv_gc_enable'")
Expand All @@ -46,15 +48,19 @@ func CheckGCEnable(ctx sessionctx.Context) (enable bool, err error) {

// DisableGC will disable GC enable variable.
func DisableGC(ctx sessionctx.Context) error {
sql := fmt.Sprintf(insertVariableValueSQL, "tikv_gc_enable", "false", "Current GC enable status")
_, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql)
stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(context.Background(), insertVariableValueSQL, "tikv_gc_enable", "false", "Current GC enable status", "false", "Current GC enable status")
if err == nil {
_, _, err = ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.Background(), stmt)
}
return errors.Trace(err)
}

// EnableGC will enable GC enable variable.
func EnableGC(ctx sessionctx.Context) error {
sql := fmt.Sprintf(insertVariableValueSQL, "tikv_gc_enable", "true", "Current GC enable status")
_, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql)
stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(context.Background(), insertVariableValueSQL, "tikv_gc_enable", "true", "Current GC enable status", "true", "Current GC enable status")
if err == nil {
_, _, err = ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.Background(), stmt)
}
return errors.Trace(err)
}

Expand All @@ -80,8 +86,12 @@ func ValidateSnapshotWithGCSafePoint(snapshotTS, safePointTS uint64) error {

// GetGCSafePoint loads GC safe point time from mysql.tidb.
func GetGCSafePoint(ctx sessionctx.Context) (uint64, error) {
sql := fmt.Sprintf(selectVariableValueSQL, "tikv_gc_safe_point")
rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql)
exec := ctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.Background(), selectVariableValueSQL, "tikv_gc_safe_point")
if err != nil {
return 0, errors.Trace(err)
}
rows, _, err := exec.ExecRestrictedStmt(context.Background(), stmt)
if err != nil {
return 0, errors.Trace(err)
}
Expand Down
Loading

0 comments on commit 7daa3ec

Please sign in to comment.