diff --git a/executor/brie.go b/executor/brie.go index 0d14265a8fd0e..492a3d68d8121 100644 --- a/executor/brie.go +++ b/executor/brie.go @@ -404,6 +404,7 @@ func (gs *tidbGlueSession) CreateSession(store kv.Storage) (glue.Session, error) // Execute implements glue.Session func (gs *tidbGlueSession) Execute(ctx context.Context, sql string) error { + // FIXME: br relies on a deprecated API, it may be unsafe _, err := gs.se.(sqlexec.SQLExecutor).Execute(ctx, sql) return err } diff --git a/executor/ddl.go b/executor/ddl.go index ae2d685a65ccf..ef903bd4e377d 100644 --- a/executor/ddl.go +++ b/executor/ddl.go @@ -309,8 +309,12 @@ func (e *DDLExec) dropTableObject(objects []*ast.TableName, obt objectType, ifEx zap.String("database", fullti.Schema.O), zap.String("table", fullti.Name.O), ) - sql := fmt.Sprintf("admin check table `%s`.`%s`", fullti.Schema.O, fullti.Name.O) - _, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := e.ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), "admin check table %n.%n", fullti.Schema.O, fullti.Name.O) + if err != nil { + return err + } + _, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return err } diff --git a/executor/grant.go b/executor/grant.go index 097e2d5dcc349..86b279727aaa5 100644 --- a/executor/grant.go +++ b/executor/grant.go @@ -16,7 +16,6 @@ package executor import ( "context" "encoding/json" - "fmt" "strings" "github.com/pingcap/errors" @@ -106,7 +105,7 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { } defer func() { if !isCommit { - _, err := internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback") + _, err := internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), "rollback") if err != nil { logutil.BgLogger().Error("rollback error occur at grant privilege", zap.Error(err)) } @@ -114,7 +113,7 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { e.releaseSysSession(internalSession) }() - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "begin") + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), "begin") if err != nil { return err } @@ -132,9 +131,7 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { if !ok { return errors.Trace(ErrPasswordFormat) } - user := fmt.Sprintf(`('%s', '%s', '%s')`, user.User.Hostname, user.User.Username, pwd) - sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, authentication_string) VALUES %s;`, mysql.SystemDB, mysql.UserTable, user) - _, err := internalSession.(sqlexec.SQLExecutor).Execute(ctx, sql) + _, err := internalSession.(sqlexec.SQLExecutor).ExecuteInternal(ctx, `INSERT INTO %n.%n (Host, User, authentication_string) VALUES (%?, %?, %?);`, mysql.SystemDB, mysql.UserTable, user.User.Hostname, user.User.Username, pwd) if err != nil { return err } @@ -193,7 +190,7 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { } } - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "commit") + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), "commit") if err != nil { return err } @@ -274,29 +271,25 @@ func (e *GrantExec) checkAndInitColumnPriv(user string, host string, cols []*ast // initGlobalPrivEntry inserts a new row into mysql.DB with empty privilege. func initGlobalPrivEntry(ctx sessionctx.Context, user string, host string) error { - sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, PRIV) VALUES ('%s', '%s', '%s')`, mysql.SystemDB, mysql.GlobalPrivTable, host, user, "{}") - _, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), `INSERT INTO %n.%n (Host, User, PRIV) VALUES (%?, %?, %?)`, mysql.SystemDB, mysql.GlobalPrivTable, host, user, "{}") return err } // initDBPrivEntry inserts a new row into mysql.DB with empty privilege. func initDBPrivEntry(ctx sessionctx.Context, user string, host string, db string) error { - sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, DB) VALUES ('%s', '%s', '%s')`, mysql.SystemDB, mysql.DBTable, host, user, db) - _, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), `INSERT INTO %n.%n (Host, User, DB) VALUES (%?, %?, %?)`, mysql.SystemDB, mysql.DBTable, host, user, db) return err } // initTablePrivEntry inserts a new row into mysql.Tables_priv with empty privilege. func initTablePrivEntry(ctx sessionctx.Context, user string, host string, db string, tbl string) error { - sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, DB, Table_name, Table_priv, Column_priv) VALUES ('%s', '%s', '%s', '%s', '', '')`, mysql.SystemDB, mysql.TablePrivTable, host, user, db, tbl) - _, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), `INSERT INTO %n.%n (Host, User, DB, Table_name, Table_priv, Column_priv) VALUES (%?, %?, %?, %?, '', '')`, mysql.SystemDB, mysql.TablePrivTable, host, user, db, tbl) return err } // initColumnPrivEntry inserts a new row into mysql.Columns_priv with empty privilege. func initColumnPrivEntry(ctx sessionctx.Context, user string, host string, db string, tbl string, col string) error { - sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, DB, Table_name, Column_name, Column_priv) VALUES ('%s', '%s', '%s', '%s', '%s', '')`, mysql.SystemDB, mysql.ColumnPrivTable, host, user, db, tbl, col) - _, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), `INSERT INTO %n.%n (Host, User, DB, Table_name, Column_name, Column_priv) VALUES (%?, %?, %?, %?, %?, '')`, mysql.SystemDB, mysql.ColumnPrivTable, host, user, db, tbl, col) return err } @@ -309,8 +302,7 @@ func (e *GrantExec) grantGlobalPriv(ctx sessionctx.Context, user *ast.UserSpec) if err != nil { return errors.Trace(err) } - sql := fmt.Sprintf(`UPDATE %s.%s SET PRIV = '%s' WHERE User='%s' AND Host='%s'`, mysql.SystemDB, mysql.GlobalPrivTable, priv, user.User.Username, user.User.Hostname) - _, err = ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + _, err = ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), `UPDATE %n.%n SET PRIV=%? WHERE User=%? AND Host=%?`, mysql.SystemDB, mysql.GlobalPrivTable, priv, user.User.Username, user.User.Hostname) return err } @@ -415,12 +407,16 @@ func (e *GrantExec) grantGlobalLevel(priv *ast.PrivElem, user *ast.UserSpec, int if priv.Priv == 0 { return nil } - asgns, err := composeGlobalPrivUpdate(priv.Priv, "Y") + + sql := new(strings.Builder) + sqlexec.MustFormatSQL(sql, `UPDATE %n.%n SET `, mysql.SystemDB, mysql.UserTable) + err := composeGlobalPrivUpdate(sql, priv.Priv, "Y") if err != nil { return err } - sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s'`, mysql.SystemDB, mysql.UserTable, asgns, user.User.Username, user.User.Hostname) - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + sqlexec.MustFormatSQL(sql, ` WHERE User=%? AND Host=%?`, user.User.Username, user.User.Hostname) + + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String()) return err } @@ -430,12 +426,16 @@ func (e *GrantExec) grantDBLevel(priv *ast.PrivElem, user *ast.UserSpec, interna if len(dbName) == 0 { dbName = e.ctx.GetSessionVars().CurrentDB } - asgns, err := composeDBPrivUpdate(priv.Priv, "Y") + + sql := new(strings.Builder) + sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.DBTable) + err := composeDBPrivUpdate(sql, priv.Priv, "Y") if err != nil { return err } - sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s';`, mysql.SystemDB, mysql.DBTable, asgns, user.User.Username, user.User.Hostname, dbName) - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%? AND DB=%?", user.User.Username, user.User.Hostname, dbName) + + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String()) return err } @@ -446,12 +446,16 @@ func (e *GrantExec) grantTableLevel(priv *ast.PrivElem, user *ast.UserSpec, inte dbName = e.ctx.GetSessionVars().CurrentDB } tblName := e.Level.TableName - asgns, err := composeTablePrivUpdateForGrant(internalSession, priv.Priv, user.User.Username, user.User.Hostname, dbName, tblName) + + sql := new(strings.Builder) + sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.TablePrivTable) + err := composeTablePrivUpdateForGrant(internalSession, sql, priv.Priv, user.User.Username, user.User.Hostname, dbName, tblName) if err != nil { return err } - sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s';`, mysql.SystemDB, mysql.TablePrivTable, asgns, user.User.Username, user.User.Hostname, dbName, tblName) - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%?", user.User.Username, user.User.Hostname, dbName, tblName) + + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String()) return err } @@ -467,12 +471,16 @@ func (e *GrantExec) grantColumnLevel(priv *ast.PrivElem, user *ast.UserSpec, int if col == nil { return errors.Errorf("Unknown column: %s", c) } - asgns, err := composeColumnPrivUpdateForGrant(internalSession, priv.Priv, user.User.Username, user.User.Hostname, dbName, tbl.Meta().Name.O, col.Name.O) + + sql := new(strings.Builder) + sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.ColumnPrivTable) + err := composeColumnPrivUpdateForGrant(internalSession, sql, priv.Priv, user.User.Username, user.User.Hostname, dbName, tbl.Meta().Name.O, col.Name.O) if err != nil { return err } - sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s' AND Column_name='%s';`, mysql.SystemDB, mysql.ColumnPrivTable, asgns, user.User.Username, user.User.Hostname, dbName, tbl.Meta().Name.O, col.Name.O) - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%? AND Column_name=%?", user.User.Username, user.User.Hostname, dbName, tbl.Meta().Name.O, col.Name.O) + + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String()) if err != nil { return err } @@ -481,178 +489,143 @@ func (e *GrantExec) grantColumnLevel(priv *ast.PrivElem, user *ast.UserSpec, int } // composeGlobalPrivUpdate composes update stmt assignment list string for global scope privilege update. -func composeGlobalPrivUpdate(priv mysql.PrivilegeType, value string) (string, error) { - if priv == mysql.AllPriv { - strs := make([]string, 0, len(mysql.Priv2UserCol)) - for _, v := range mysql.AllGlobalPrivs { - strs = append(strs, fmt.Sprintf(`%s='%s'`, mysql.Priv2UserCol[v], value)) +func composeGlobalPrivUpdate(sql *strings.Builder, priv mysql.PrivilegeType, value string) error { + if priv != mysql.AllPriv { + col, ok := mysql.Priv2UserCol[priv] + if !ok { + return errors.Errorf("Unknown priv: %v", priv) } - return strings.Join(strs, ", "), nil + sqlexec.MustFormatSQL(sql, "%n=%?", col, value) + return nil } - col, ok := mysql.Priv2UserCol[priv] - if !ok { - return "", errors.Errorf("Unknown priv: %v", priv) + + for i, v := range mysql.AllGlobalPrivs { + if i > 0 { + sqlexec.MustFormatSQL(sql, ",") + } + + k, ok := mysql.Priv2UserCol[v] + if !ok { + return errors.Errorf("Unknown priv %v", priv) + } + + sqlexec.MustFormatSQL(sql, "%n=%?", k, value) } - return fmt.Sprintf(`%s='%s'`, col, value), nil + return nil } // composeDBPrivUpdate composes update stmt assignment list for db scope privilege update. -func composeDBPrivUpdate(priv mysql.PrivilegeType, value string) (string, error) { - if priv == mysql.AllPriv { - strs := make([]string, 0, len(mysql.AllDBPrivs)) - for _, p := range mysql.AllDBPrivs { - v, ok := mysql.Priv2UserCol[p] - if !ok { - return "", errors.Errorf("Unknown db privilege %v", priv) - } - strs = append(strs, fmt.Sprintf(`%s='%s'`, v, value)) +func composeDBPrivUpdate(sql *strings.Builder, priv mysql.PrivilegeType, value string) error { + if priv != mysql.AllPriv { + col, ok := mysql.Priv2UserCol[priv] + if !ok { + return errors.Errorf("Unknown priv: %v", priv) } - return strings.Join(strs, ", "), nil - } - col, ok := mysql.Priv2UserCol[priv] - if !ok { - return "", errors.Errorf("Unknown priv: %v", priv) + sqlexec.MustFormatSQL(sql, "%n=%?", col, value) + return nil } - return fmt.Sprintf(`%s='%s'`, col, value), nil -} -// composeTablePrivUpdateForGrant composes update stmt assignment list for table scope privilege update. -func composeTablePrivUpdateForGrant(ctx sessionctx.Context, priv mysql.PrivilegeType, name string, host string, db string, tbl string) (string, error) { - var newTablePriv, newColumnPriv string - if priv == mysql.AllPriv { - for _, p := range mysql.AllTablePrivs { - v, ok := mysql.Priv2SetStr[p] - if !ok { - return "", errors.Errorf("Unknown table privilege %v", p) - } - newTablePriv = addToSet(newTablePriv, v) - } - for _, p := range mysql.AllColumnPrivs { - v, ok := mysql.Priv2SetStr[p] - if !ok { - return "", errors.Errorf("Unknown column privilege %v", p) - } - newColumnPriv = addToSet(newColumnPriv, v) - } - } else { - currTablePriv, currColumnPriv, err := getTablePriv(ctx, name, host, db, tbl) - if err != nil { - return "", err + for i, p := range mysql.AllDBPrivs { + if i > 0 { + sqlexec.MustFormatSQL(sql, ",") } - p, ok := mysql.Priv2SetStr[priv] + + v, ok := mysql.Priv2UserCol[p] if !ok { - return "", errors.Errorf("Unknown priv: %v", priv) + return errors.Errorf("Unknown priv %v", priv) } - newTablePriv = addToSet(currTablePriv, p) - for _, cp := range mysql.AllColumnPrivs { - if priv == cp { - newColumnPriv = addToSet(currColumnPriv, p) - break - } - } + sqlexec.MustFormatSQL(sql, "%n=%?", v, value) } - return fmt.Sprintf(`Table_priv='%s', Column_priv='%s', Grantor='%s'`, newTablePriv, newColumnPriv, ctx.GetSessionVars().User), nil + return nil } -func composeTablePrivUpdateForRevoke(ctx sessionctx.Context, priv mysql.PrivilegeType, name string, host string, db string, tbl string) (string, error) { - var newTablePriv, newColumnPriv string - if priv == mysql.AllPriv { - newTablePriv = "" - newColumnPriv = "" - } else { +func privUpdateForGrant(cur []string, priv mysql.PrivilegeType) ([]string, error) { + p, ok := mysql.Priv2SetStr[priv] + if !ok { + return nil, errors.Errorf("Unknown priv: %v", priv) + } + cur = addToSet(cur, p) + return cur, nil +} + +// composeTablePrivUpdateForGrant composes update stmt assignment list for table scope privilege update. +func composeTablePrivUpdateForGrant(ctx sessionctx.Context, sql *strings.Builder, priv mysql.PrivilegeType, name string, host string, db string, tbl string) error { + var newTablePriv, newColumnPriv []string + var tblPrivs, colPrivs []mysql.PrivilegeType + if priv != mysql.AllPriv { currTablePriv, currColumnPriv, err := getTablePriv(ctx, name, host, db, tbl) if err != nil { - return "", err - } - p, ok := mysql.Priv2SetStr[priv] - if !ok { - return "", errors.Errorf("Unknown priv: %v", priv) + return err } - newTablePriv = deleteFromSet(currTablePriv, p) - + newTablePriv = setFromString(currTablePriv) + newColumnPriv = setFromString(currColumnPriv) + tblPrivs = []mysql.PrivilegeType{priv} for _, cp := range mysql.AllColumnPrivs { - if priv == cp { - newColumnPriv = deleteFromSet(currColumnPriv, p) + // in case it is not a column priv + if cp == priv { + colPrivs = []mysql.PrivilegeType{priv} break } } + } else { + tblPrivs = mysql.AllTablePrivs + colPrivs = mysql.AllColumnPrivs } - return fmt.Sprintf(`Table_priv='%s', Column_priv='%s', Grantor='%s'`, newTablePriv, newColumnPriv, ctx.GetSessionVars().User), nil -} -// addToSet add a value to the set, e.g: -// addToSet("Select,Insert", "Update") returns "Select,Insert,Update". -func addToSet(set string, value string) string { - if set == "" { - return value + var err error + for _, p := range tblPrivs { + newTablePriv, err = privUpdateForGrant(newTablePriv, p) + if err != nil { + return err + } } - return fmt.Sprintf("%s,%s", set, value) -} -// deleteFromSet delete the value from the set, e.g: -// deleteFromSet("Select,Insert,Update", "Update") returns "Select,Insert". -func deleteFromSet(set string, value string) string { - sets := strings.Split(set, ",") - res := make([]string, 0, len(sets)) - for _, v := range sets { - if v != value { - res = append(res, v) + for _, p := range colPrivs { + newColumnPriv, err = privUpdateForGrant(newColumnPriv, p) + if err != nil { + return err } } - return strings.Join(res, ",") + + sqlexec.MustFormatSQL(sql, `Table_priv=%?, Column_priv=%?, Grantor=%?`, strings.Join(newTablePriv, ","), strings.Join(newColumnPriv, ","), ctx.GetSessionVars().User.String()) + return nil } // composeColumnPrivUpdateForGrant composes update stmt assignment list for column scope privilege update. -func composeColumnPrivUpdateForGrant(ctx sessionctx.Context, priv mysql.PrivilegeType, name string, host string, db string, tbl string, col string) (string, error) { - newColumnPriv := "" - if priv == mysql.AllPriv { - for _, p := range mysql.AllColumnPrivs { - v, ok := mysql.Priv2SetStr[p] - if !ok { - return "", errors.Errorf("Unknown column privilege %v", p) - } - newColumnPriv = addToSet(newColumnPriv, v) - } - } else { +func composeColumnPrivUpdateForGrant(ctx sessionctx.Context, sql *strings.Builder, priv mysql.PrivilegeType, name string, host string, db string, tbl string, col string) error { + var newColumnPriv []string + var colPrivs []mysql.PrivilegeType + if priv != mysql.AllPriv { currColumnPriv, err := getColumnPriv(ctx, name, host, db, tbl, col) if err != nil { - return "", err - } - p, ok := mysql.Priv2SetStr[priv] - if !ok { - return "", errors.Errorf("Unknown priv: %v", priv) + return err } - newColumnPriv = addToSet(currColumnPriv, p) + newColumnPriv = setFromString(currColumnPriv) + colPrivs = []mysql.PrivilegeType{priv} + } else { + colPrivs = mysql.AllColumnPrivs } - return fmt.Sprintf(`Column_priv='%s'`, newColumnPriv), nil -} -func composeColumnPrivUpdateForRevoke(ctx sessionctx.Context, priv mysql.PrivilegeType, name string, host string, db string, tbl string, col string) (string, error) { - newColumnPriv := "" - if priv == mysql.AllPriv { - newColumnPriv = "" - } else { - currColumnPriv, err := getColumnPriv(ctx, name, host, db, tbl, col) + var err error + for _, p := range colPrivs { + newColumnPriv, err = privUpdateForGrant(newColumnPriv, p) if err != nil { - return "", err - } - p, ok := mysql.Priv2SetStr[priv] - if !ok { - return "", errors.Errorf("Unknown priv: %v", priv) + return err } - newColumnPriv = deleteFromSet(currColumnPriv, p) } - return fmt.Sprintf(`Column_priv='%s'`, newColumnPriv), nil + + sqlexec.MustFormatSQL(sql, `Column_priv=%?`, strings.Join(newColumnPriv, ",")) + return nil } // recordExists is a helper function to check if the sql returns any row. -func recordExists(ctx sessionctx.Context, sql string) (bool, error) { - recordSets, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) +func recordExists(ctx sessionctx.Context, sql string, args ...interface{}) (bool, error) { + rs, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql, args...) if err != nil { return false, err } - rows, _, err := getRowsAndFields(ctx, recordSets) + rows, _, err := getRowsAndFields(ctx, rs) if err != nil { return false, err } @@ -661,43 +634,35 @@ func recordExists(ctx sessionctx.Context, sql string) (bool, error) { // globalPrivEntryExists checks if there is an entry with key user-host in mysql.global_priv. func globalPrivEntryExists(ctx sessionctx.Context, name string, host string) (bool, error) { - sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User='%s' AND Host='%s';`, mysql.SystemDB, mysql.GlobalPrivTable, name, host) - return recordExists(ctx, sql) + return recordExists(ctx, `SELECT * FROM %n.%n WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.GlobalPrivTable, name, host) } // dbUserExists checks if there is an entry with key user-host-db in mysql.DB. func dbUserExists(ctx sessionctx.Context, name string, host string, db string) (bool, error) { - sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User='%s' AND Host='%s' AND DB='%s';`, mysql.SystemDB, mysql.DBTable, name, host, db) - return recordExists(ctx, sql) + return recordExists(ctx, `SELECT * FROM %n.%n WHERE User=%? AND Host=%? AND DB=%?;`, mysql.SystemDB, mysql.DBTable, name, host, db) } // tableUserExists checks if there is an entry with key user-host-db-tbl in mysql.Tables_priv. func tableUserExists(ctx sessionctx.Context, name string, host string, db string, tbl string) (bool, error) { - sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s';`, mysql.SystemDB, mysql.TablePrivTable, name, host, db, tbl) - return recordExists(ctx, sql) + return recordExists(ctx, `SELECT * FROM %n.%n WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%?;`, mysql.SystemDB, mysql.TablePrivTable, name, host, db, tbl) } // columnPrivEntryExists checks if there is an entry with key user-host-db-tbl-col in mysql.Columns_priv. func columnPrivEntryExists(ctx sessionctx.Context, name string, host string, db string, tbl string, col string) (bool, error) { - sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s' AND Column_name='%s';`, mysql.SystemDB, mysql.ColumnPrivTable, name, host, db, tbl, col) - return recordExists(ctx, sql) + return recordExists(ctx, `SELECT * FROM %n.%n WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%? AND Column_name=%?;`, mysql.SystemDB, mysql.ColumnPrivTable, name, host, db, tbl, col) } // getTablePriv gets current table scope privilege set from mysql.Tables_priv. // Return Table_priv and Column_priv. func getTablePriv(ctx sessionctx.Context, name string, host string, db string, tbl string) (string, string, error) { - sql := fmt.Sprintf(`SELECT Table_priv, Column_priv FROM %s.%s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s';`, mysql.SystemDB, mysql.TablePrivTable, name, host, db, tbl) - rs, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + rs, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), `SELECT Table_priv, Column_priv FROM %n.%n WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%?`, mysql.SystemDB, mysql.TablePrivTable, name, host, db, tbl) if err != nil { return "", "", err } - if len(rs) < 1 { - return "", "", errors.Errorf("get table privilege fail for %s %s %s %s", name, host, db, tbl) - } var tPriv, cPriv string rows, fields, err := getRowsAndFields(ctx, rs) if err != nil { - return "", "", err + return "", "", errors.Errorf("get table privilege fail for %s %s %s %s: %v", name, host, db, tbl, err) } if len(rows) < 1 { return "", "", errors.Errorf("get table privilege fail for %s %s %s %s", name, host, db, tbl) @@ -717,17 +682,13 @@ func getTablePriv(ctx sessionctx.Context, name string, host string, db string, t // getColumnPriv gets current column scope privilege set from mysql.Columns_priv. // Return Column_priv. func getColumnPriv(ctx sessionctx.Context, name string, host string, db string, tbl string, col string) (string, error) { - sql := fmt.Sprintf(`SELECT Column_priv FROM %s.%s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s' AND Column_name='%s';`, mysql.SystemDB, mysql.ColumnPrivTable, name, host, db, tbl, col) - rs, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + rs, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), `SELECT Column_priv FROM %n.%n WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%? AND Column_name=%?;`, mysql.SystemDB, mysql.ColumnPrivTable, name, host, db, tbl, col) if err != nil { return "", err } - if len(rs) < 1 { - return "", errors.Errorf("get column privilege fail for %s %s %s %s", name, host, db, tbl) - } rows, fields, err := getRowsAndFields(ctx, rs) if err != nil { - return "", err + return "", errors.Errorf("get column privilege fail for %s %s %s %s: %s", name, host, db, tbl, err) } if len(rows) < 1 { return "", errors.Errorf("get column privilege fail for %s %s %s %s %s", name, host, db, tbl, col) @@ -757,27 +718,18 @@ func getTargetSchemaAndTable(ctx sessionctx.Context, dbName, tableName string, i } // getRowsAndFields is used to extract rows from record sets. -func getRowsAndFields(ctx sessionctx.Context, recordSets []sqlexec.RecordSet) ([]chunk.Row, []*ast.ResultField, error) { - var ( - rows []chunk.Row - fields []*ast.ResultField - ) - - for i, rs := range recordSets { - tmp, err := getRowFromRecordSet(context.Background(), ctx, rs) - if err != nil { - return nil, nil, err - } - if err = rs.Close(); err != nil { - return nil, nil, err - } - - if i == 0 { - rows = tmp - fields = rs.Fields() - } +func getRowsAndFields(ctx sessionctx.Context, rs sqlexec.RecordSet) ([]chunk.Row, []*ast.ResultField, error) { + if rs == nil { + return nil, nil, errors.Errorf("nil recordset") + } + rows, err := getRowFromRecordSet(context.Background(), ctx, rs) + if err != nil { + return nil, nil, err + } + if err = rs.Close(); err != nil { + return nil, nil, err } - return rows, fields, nil + return rows, rs.Fields(), nil } func getRowFromRecordSet(ctx context.Context, se sessionctx.Context, rs sqlexec.RecordSet) ([]chunk.Row, error) { diff --git a/executor/infoschema_reader.go b/executor/infoschema_reader.go index 72880529ad08f..23cacc1558e76 100644 --- a/executor/infoschema_reader.go +++ b/executor/infoschema_reader.go @@ -154,10 +154,16 @@ func (e *memtableRetriever) retrieve(ctx context.Context, sctx sessionctx.Contex } func getRowCountAllTable(ctx sessionctx.Context) (map[int64]uint64, error) { - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL("select table_id, count from mysql.stats_meta") + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), "select table_id, count from mysql.stats_meta") if err != nil { return nil, err } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) + if err != nil { + return nil, err + } + rowCountMap := make(map[int64]uint64, len(rows)) for _, row := range rows { tableID := row.GetInt64(0) @@ -173,10 +179,16 @@ type tableHistID struct { } func getColLengthAllTables(ctx sessionctx.Context) (map[tableHistID]uint64, error) { - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL("select table_id, hist_id, tot_col_size from mysql.stats_histograms where is_index = 0") + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), "select table_id, hist_id, tot_col_size from mysql.stats_histograms where is_index = 0") if err != nil { return nil, err } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) + if err != nil { + return nil, err + } + colLengthMap := make(map[tableHistID]uint64, len(rows)) for _, row := range rows { tableID := row.GetInt64(0) diff --git a/executor/inspection_profile.go b/executor/inspection_profile.go index f243db364f0d8..f15dd6ef5e6ff 100644 --- a/executor/inspection_profile.go +++ b/executor/inspection_profile.go @@ -161,7 +161,13 @@ func (n *metricNode) getLabelValue(label string) *metricValue { } func (n *metricNode) queryRowsByLabel(pb *profileBuilder, query string, handleRowFn func(label string, v float64)) error { - rows, _, err := pb.sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(context.Background(), query) + exec := pb.sctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), query) + if err != nil { + return err + } + + rows, _, err := pb.sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return err } diff --git a/executor/inspection_result.go b/executor/inspection_result.go index 4d8bd69ca3edb..5bb858a8224e9 100644 --- a/executor/inspection_result.go +++ b/executor/inspection_result.go @@ -141,8 +141,12 @@ func (e *inspectionResultRetriever) retrieve(ctx context.Context, sctx sessionct // Get cluster info. e.instanceToStatusAddress = make(map[string]string) e.statusToInstanceAddress = make(map[string]string) - sql := "select instance,status_address from information_schema.cluster_info;" - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + var rows []chunk.Row + exec := sctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(ctx, "select instance,status_address from information_schema.cluster_info;") + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("get cluster info failed: %v", err)) } @@ -247,16 +251,22 @@ func (configInspection) inspectDiffConfig(ctx context.Context, sctx sessionctx.C "storage.data-dir", "storage.block-cache.capacity", } - sql := fmt.Sprintf("select type, `key`, count(distinct value) as c from information_schema.cluster_config where `key` not in ('%s') group by type, `key` having c > 1", - strings.Join(ignoreConfigKey, "','")) - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + var rows []chunk.Row + exec := sctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(ctx, "select type, `key`, count(distinct value) as c from information_schema.cluster_config where `key` not in (%?) group by type, `key` having c > 1", ignoreConfigKey) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration consistency failed: %v", err)) } generateDetail := func(tp, item string) string { - query := fmt.Sprintf("select value, instance from information_schema.cluster_config where type='%s' and `key`='%s';", tp, item) - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, query) + var rows []chunk.Row + stmt, err := exec.ParseWithParams(ctx, "select value, instance from information_schema.cluster_config where type=%? and `key`=%?;", tp, item) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration consistency failed: %v", err)) return fmt.Sprintf("the cluster has different config value of %[2]s, execute the sql to see more detail: select * from information_schema.cluster_config where type='%[1]s' and `key`='%[2]s'", @@ -318,13 +328,18 @@ func (c configInspection) inspectCheckConfig(ctx context.Context, sctx sessionct } var results []inspectionResult + var rows []chunk.Row + sql := new(strings.Builder) + exec := sctx.(sqlexec.RestrictedSQLExecutor) for _, cas := range cases { if !filter.enable(cas.key) { continue } - sql := fmt.Sprintf("select instance from information_schema.cluster_config where type = '%s' and `key` = '%s' and value = '%s'", - cas.tp, cas.key, cas.value) - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + sql.Reset() + stmt, err := exec.ParseWithParams(ctx, "select instance from information_schema.cluster_config where type = %? and %n = %? and value = %?", cas.tp, "key", cas.key, cas.value) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration in reason failed: %v", err)) } @@ -350,8 +365,12 @@ func (c configInspection) checkTiKVBlockCacheSizeConfig(ctx context.Context, sct if !filter.enable(item) { return nil } - sql := "select instance,value from information_schema.cluster_config where type='tikv' and `key` = 'storage.block-cache.capacity'" - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + var rows []chunk.Row + exec := sctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(ctx, "select instance,value from information_schema.cluster_config where type='tikv' and `key` = 'storage.block-cache.capacity'") + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration in reason failed: %v", err)) } @@ -375,8 +394,10 @@ func (c configInspection) checkTiKVBlockCacheSizeConfig(ctx context.Context, sct ipToCount[ip]++ } - sql = "select instance, value from metrics_schema.node_total_memory where time=now()" - rows, _, err = sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + stmt, err = exec.ParseWithParams(ctx, "select instance, value from metrics_schema.node_total_memory where time=now()") + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration in reason failed: %v", err)) } @@ -438,9 +459,13 @@ func (configInspection) convertReadableSizeToByteSize(sizeStr string) (uint64, e } func (versionInspection) inspect(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { + exec := sctx.(sqlexec.RestrictedSQLExecutor) + var rows []chunk.Row // check the configuration consistent - sql := "select type, count(distinct git_hash) as c from information_schema.cluster_info group by type having c > 1;" - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + stmt, err := exec.ParseWithParams(ctx, "select type, count(distinct git_hash) as c from information_schema.cluster_info group by type having c > 1;") + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check version consistency failed: %v", err)) } @@ -594,6 +619,9 @@ func (criticalErrorInspection) inspectError(ctx context.Context, sctx sessionctx condition := filter.timeRange.Condition() var results []inspectionResult + var rows []chunk.Row + exec := sctx.(sqlexec.RestrictedSQLExecutor) + sql := new(strings.Builder) for _, rule := range rules { if filter.enable(rule.item) { def, found := infoschema.MetricTableMap[rule.tbl] @@ -601,9 +629,13 @@ func (criticalErrorInspection) inspectError(ctx context.Context, sctx sessionctx sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("metrics table: %s not found", rule.tbl)) continue } - sql := fmt.Sprintf("select `%[1]s`,sum(value) as total from `%[2]s`.`%[3]s` %[4]s group by `%[1]s` having total>=1.0", + sql.Reset() + fmt.Fprintf(sql, "select `%[1]s`,sum(value) as total from `%[2]s`.`%[3]s` %[4]s group by `%[1]s` having total>=1.0", strings.Join(def.Labels, "`,`"), util.MetricSchemaName.L, rule.tbl, condition) - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + stmt, err := exec.ParseWithParams(ctx, sql.String()) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) continue @@ -649,10 +681,16 @@ func (criticalErrorInspection) inspectForServerDown(ctx context.Context, sctx se return nil } condition := filter.timeRange.Condition() - sql := fmt.Sprintf(`select t1.job,t1.instance, t2.min_time from + exec := sctx.(sqlexec.RestrictedSQLExecutor) + sql := new(strings.Builder) + fmt.Fprintf(sql, `select t1.job,t1.instance, t2.min_time from (select instance,job from metrics_schema.up %[1]s group by instance,job having max(value)-min(value)>0) as t1 join (select instance,min(time) as min_time from metrics_schema.up %[1]s and value=0 group by instance,job) as t2 on t1.instance=t2.instance order by job`, condition) - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + var rows []chunk.Row + stmt, err := exec.ParseWithParams(ctx, sql.String()) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) } @@ -675,8 +713,12 @@ func (criticalErrorInspection) inspectForServerDown(ctx context.Context, sctx se results = append(results, result) } // Check from log. - sql = fmt.Sprintf("select type,instance,time from information_schema.cluster_log %s and level = 'info' and message like '%%Welcome to'", condition) - rows, _, err = sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + sql.Reset() + fmt.Fprintf(sql, "select type,instance,time from information_schema.cluster_log %s and level = 'info' and message like '%%Welcome to'", condition) + stmt, err = exec.ParseWithParams(ctx, sql.String()) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) } @@ -790,24 +832,30 @@ func (thresholdCheckInspection) inspectThreshold1(ctx context.Context, sctx sess condition := filter.timeRange.Condition() var results []inspectionResult + var rows []chunk.Row + exec := sctx.(sqlexec.RestrictedSQLExecutor) + sql := new(strings.Builder) for _, rule := range rules { if !filter.enable(rule.item) { continue } - var sql string + sql.Reset() if len(rule.configKey) > 0 { - sql = fmt.Sprintf("select t1.status_address, t1.cpu, (t2.value * %[2]f) as threshold, t2.value from "+ - "(select status_address, max(sum_value) as cpu from (select instance as status_address, sum(value) as sum_value from metrics_schema.tikv_thread_cpu %[4]s and name like '%[1]s' group by instance, time) as tmp group by tmp.status_address) as t1 join "+ - "(select instance, value from information_schema.cluster_config where type='tikv' and `key` = '%[3]s') as t2 join "+ - "(select instance,status_address from information_schema.cluster_info where type='tikv') as t3 "+ - "on t1.status_address=t3.status_address and t2.instance=t3.instance where t1.cpu > (t2.value * %[2]f)", rule.component, rule.threshold, rule.configKey, condition) + fmt.Fprintf(sql, `select t1.status_address, t1.cpu, (t2.value * %[2]f) as threshold, t2.value from + (select status_address, max(sum_value) as cpu from (select instance as status_address, sum(value) as sum_value from metrics_schema.tikv_thread_cpu %[4]s and name like '%[1]s' group by instance, time) as tmp group by tmp.status_address) as t1 join + (select instance, value from information_schema.cluster_config where type='tikv' and %[5]s = '%[3]s') as t2 join + (select instance,status_address from information_schema.cluster_info where type='tikv') as t3 + on t1.status_address=t3.status_address and t2.instance=t3.instance where t1.cpu > (t2.value * %[2]f)`, rule.component, rule.threshold, rule.configKey, condition, "`key`") } else { - sql = fmt.Sprintf("select t1.instance, t1.cpu, %[2]f from "+ - "(select instance, max(value) as cpu from metrics_schema.tikv_thread_cpu %[3]s and name like '%[1]s' group by instance) as t1 "+ - "where t1.cpu > %[2]f;", rule.component, rule.threshold, condition) + fmt.Fprintf(sql, `select t1.instance, t1.cpu, %[2]f from + (select instance, max(value) as cpu from metrics_schema.tikv_thread_cpu %[3]s and name like '%[1]s' group by instance) as t1 + where t1.cpu > %[2]f;`, rule.component, rule.threshold, condition) + } + stmt, err := exec.ParseWithParams(ctx, sql.String()) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) continue @@ -957,11 +1005,13 @@ func (thresholdCheckInspection) inspectThreshold2(ctx context.Context, sctx sess condition := filter.timeRange.Condition() var results []inspectionResult + var rows []chunk.Row + sql := new(strings.Builder) + exec := sctx.(sqlexec.RestrictedSQLExecutor) for _, rule := range rules { if !filter.enable(rule.item) { continue } - var sql string cond := condition if len(rule.condition) > 0 { cond = fmt.Sprintf("%s and %s", cond, rule.condition) @@ -969,12 +1019,16 @@ func (thresholdCheckInspection) inspectThreshold2(ctx context.Context, sctx sess if rule.factor == 0 { rule.factor = 1 } + sql.Reset() if rule.isMin { - sql = fmt.Sprintf("select instance, min(value)/%.0f as min_value from metrics_schema.%s %s group by instance having min_value < %f;", rule.factor, rule.tbl, cond, rule.threshold) + fmt.Fprintf(sql, "select instance, min(value)/%.0f as min_value from metrics_schema.%s %s group by instance having min_value < %f;", rule.factor, rule.tbl, cond, rule.threshold) } else { - sql = fmt.Sprintf("select instance, max(value)/%.0f as max_value from metrics_schema.%s %s group by instance having max_value > %f;", rule.factor, rule.tbl, cond, rule.threshold) + fmt.Fprintf(sql, "select instance, max(value)/%.0f as max_value from metrics_schema.%s %s group by instance having max_value > %f;", rule.factor, rule.tbl, cond, rule.threshold) + } + stmt, err := exec.ParseWithParams(ctx, sql.String()) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) continue @@ -1150,12 +1204,17 @@ func (thresholdCheckInspection) inspectThreshold3(ctx context.Context, sctx sess func checkRules(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter, rules []ruleChecker) []inspectionResult { var results []inspectionResult + var rows []chunk.Row + exec := sctx.(sqlexec.RestrictedSQLExecutor) for _, rule := range rules { if !filter.enable(rule.getItem()) { continue } sql := rule.genSQL(filter.timeRange) - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + stmt, err := exec.ParseWithParams(ctx, sql) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) continue @@ -1170,8 +1229,15 @@ func checkRules(ctx context.Context, sctx sessionctx.Context, filter inspectionF func (c thresholdCheckInspection) inspectForLeaderDrop(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { condition := filter.timeRange.Condition() threshold := 50.0 - sql := fmt.Sprintf(`select address,min(value) as mi,max(value) as mx from metrics_schema.pd_scheduler_store_status %s and type='leader_count' group by address having mx-mi>%v`, condition, threshold) - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + sql := new(strings.Builder) + fmt.Fprintf(sql, `select address,min(value) as mi,max(value) as mx from metrics_schema.pd_scheduler_store_status %s and type='leader_count' group by address having mx-mi>%v`, condition, threshold) + exec := sctx.(sqlexec.RestrictedSQLExecutor) + + var rows []chunk.Row + stmt, err := exec.ParseWithParams(ctx, sql.String()) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) return nil @@ -1179,12 +1245,18 @@ func (c thresholdCheckInspection) inspectForLeaderDrop(ctx context.Context, sctx var results []inspectionResult for _, row := range rows { address := row.GetString(0) - sql := fmt.Sprintf(`select time, value from metrics_schema.pd_scheduler_store_status %s and type='leader_count' and address = '%s' order by time`, condition, address) - subRows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + sql.Reset() + fmt.Fprintf(sql, `select time, value from metrics_schema.pd_scheduler_store_status %s and type='leader_count' and address = '%s' order by time`, condition, address) + var subRows []chunk.Row + stmt, err := exec.ParseWithParams(ctx, sql.String()) + if err == nil { + subRows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) continue } + lastValue := float64(0) for i, subRows := range subRows { v := subRows.GetFloat64(1) diff --git a/executor/inspection_summary.go b/executor/inspection_summary.go index 37aef042a62f9..2a01cab6dc402 100644 --- a/executor/inspection_summary.go +++ b/executor/inspection_summary.go @@ -458,7 +458,12 @@ func (e *inspectionSummaryRetriever) retrieve(ctx context.Context, sctx sessionc sql = fmt.Sprintf("select avg(value),min(value),max(value) from `%s`.`%s` %s", util.MetricSchemaName.L, name, cond) } - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + exec := sctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(ctx, sql) + if err != nil { + return nil, errors.Errorf("execute '%s' failed: %v", sql, err) + } + rows, _, err := exec.ExecRestrictedStmt(ctx, stmt) if err != nil { return nil, errors.Errorf("execute '%s' failed: %v", sql, err) } diff --git a/executor/metrics_reader.go b/executor/metrics_reader.go index 91cb8159bfc27..ff582c23b9935 100644 --- a/executor/metrics_reader.go +++ b/executor/metrics_reader.go @@ -189,7 +189,7 @@ type MetricsSummaryRetriever struct { retrieved bool } -func (e *MetricsSummaryRetriever) retrieve(_ context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { +func (e *MetricsSummaryRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { if e.retrieved || e.extractor.SkipRequest { return nil, nil } @@ -229,7 +229,12 @@ func (e *MetricsSummaryRetriever) retrieve(_ context.Context, sctx sessionctx.Co name, util.MetricSchemaName.L, condition) } - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := sctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(ctx, sql) + if err != nil { + return nil, errors.Errorf("execute '%s' failed: %v", sql, err) + } + rows, _, err := exec.ExecRestrictedStmt(ctx, stmt) if err != nil { return nil, errors.Errorf("execute '%s' failed: %v", sql, err) } @@ -306,7 +311,12 @@ func (e *MetricsSummaryByLabelRetriever) retrieve(ctx context.Context, sctx sess sql = fmt.Sprintf("select sum(value),avg(value),min(value),max(value) from `%s`.`%s` %s", util.MetricSchemaName.L, name, cond) } - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + exec := sctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(ctx, sql) + if err != nil { + return nil, errors.Errorf("execute '%s' failed: %v", sql, err) + } + rows, _, err := exec.ExecRestrictedStmt(ctx, stmt) if err != nil { return nil, errors.Errorf("execute '%s' failed: %v", sql, err) } diff --git a/executor/opt_rule_blacklist.go b/executor/opt_rule_blacklist.go index 8bb55c16f52e5..76cdc74ea1d11 100644 --- a/executor/opt_rule_blacklist.go +++ b/executor/opt_rule_blacklist.go @@ -35,8 +35,12 @@ func (e *ReloadOptRuleBlacklistExec) Next(ctx context.Context, _ *chunk.Chunk) e // LoadOptRuleBlacklist loads the latest data from table mysql.opt_rule_blacklist. func LoadOptRuleBlacklist(ctx sessionctx.Context) (err error) { - sql := "select HIGH_PRIORITY name from mysql.opt_rule_blacklist" - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), "select HIGH_PRIORITY name from mysql.opt_rule_blacklist") + if err != nil { + return err + } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return err } diff --git a/executor/prepared.go b/executor/prepared.go index f57a1806b1ed9..9f8a427b84afa 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -116,6 +116,7 @@ func (e *PrepareExec) Next(ctx context.Context, req *chunk.Chunk) error { err error ) if sqlParser, ok := e.ctx.(sqlexec.SQLParser); ok { + // FIXME: ok... yet another parse API, may need some api interface clean. stmts, err = sqlParser.ParseSQL(e.sqlText, charset, collation) } else { p := parser.New() diff --git a/executor/reload_expr_pushdown_blacklist.go b/executor/reload_expr_pushdown_blacklist.go index 3d0752e08463d..5783438813954 100644 --- a/executor/reload_expr_pushdown_blacklist.go +++ b/executor/reload_expr_pushdown_blacklist.go @@ -37,8 +37,12 @@ func (e *ReloadExprPushdownBlacklistExec) Next(ctx context.Context, _ *chunk.Chu // LoadExprPushdownBlacklist loads the latest data from table mysql.expr_pushdown_blacklist. func LoadExprPushdownBlacklist(ctx sessionctx.Context) (err error) { - sql := "select HIGH_PRIORITY name, store_type from mysql.expr_pushdown_blacklist" - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), "select HIGH_PRIORITY name, store_type from mysql.expr_pushdown_blacklist") + if err != nil { + return err + } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return err } diff --git a/executor/revoke.go b/executor/revoke.go index 7e51aa8ac82a4..fb722c89acf40 100644 --- a/executor/revoke.go +++ b/executor/revoke.go @@ -15,7 +15,7 @@ package executor import ( "context" - "fmt" + "strings" "github.com/pingcap/errors" "github.com/pingcap/parser/ast" @@ -73,7 +73,7 @@ func (e *RevokeExec) Next(ctx context.Context, req *chunk.Chunk) error { } defer func() { if !isCommit { - _, err := internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback") + _, err := internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), "rollback") if err != nil { logutil.BgLogger().Error("rollback error occur at grant privilege", zap.Error(err)) } @@ -81,7 +81,7 @@ func (e *RevokeExec) Next(ctx context.Context, req *chunk.Chunk) error { e.releaseSysSession(internalSession) }() - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "begin") + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), "begin") if err != nil { return err } @@ -103,7 +103,7 @@ func (e *RevokeExec) Next(ctx context.Context, req *chunk.Chunk) error { } } - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "commit") + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), "commit") if err != nil { return err } @@ -166,12 +166,15 @@ func (e *RevokeExec) revokePriv(internalSession sessionctx.Context, priv *ast.Pr } func (e *RevokeExec) revokeGlobalPriv(internalSession sessionctx.Context, priv *ast.PrivElem, user, host string) error { - asgns, err := composeGlobalPrivUpdate(priv.Priv, "N") + sql := new(strings.Builder) + sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.UserTable) + err := composeGlobalPrivUpdate(sql, priv.Priv, "N") if err != nil { return err } - sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s'`, mysql.SystemDB, mysql.UserTable, asgns, user, host) - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%?", user, host) + + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String()) return err } @@ -180,12 +183,16 @@ func (e *RevokeExec) revokeDBPriv(internalSession sessionctx.Context, priv *ast. if len(dbName) == 0 { dbName = e.ctx.GetSessionVars().CurrentDB } - asgns, err := composeDBPrivUpdate(priv.Priv, "N") + + sql := new(strings.Builder) + sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.DBTable) + err := composeDBPrivUpdate(sql, priv.Priv, "N") if err != nil { return err } - sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s';`, mysql.SystemDB, mysql.DBTable, asgns, userName, host, dbName) - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%? AND DB=%?", userName, host, dbName) + + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String()) return err } @@ -194,12 +201,16 @@ func (e *RevokeExec) revokeTablePriv(internalSession sessionctx.Context, priv *a if err != nil { return err } - asgns, err := composeTablePrivUpdateForRevoke(internalSession, priv.Priv, user, host, dbName, tbl.Meta().Name.O) + + sql := new(strings.Builder) + sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.TablePrivTable) + err = composeTablePrivUpdateForRevoke(internalSession, sql, priv.Priv, user, host, dbName, tbl.Meta().Name.O) if err != nil { return err } - sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s';`, mysql.SystemDB, mysql.TablePrivTable, asgns, user, host, dbName, tbl.Meta().Name.O) - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%?", user, host, dbName, tbl.Meta().Name.O) + + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String()) return err } @@ -208,20 +219,80 @@ func (e *RevokeExec) revokeColumnPriv(internalSession sessionctx.Context, priv * if err != nil { return err } + sql := new(strings.Builder) for _, c := range priv.Cols { col := table.FindCol(tbl.Cols(), c.Name.L) if col == nil { return errors.Errorf("Unknown column: %s", c) } - asgns, err := composeColumnPrivUpdateForRevoke(internalSession, priv.Priv, user, host, dbName, tbl.Meta().Name.O, col.Name.O) + + sql.Reset() + sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.ColumnPrivTable) + err = composeColumnPrivUpdateForRevoke(internalSession, sql, priv.Priv, user, host, dbName, tbl.Meta().Name.O, col.Name.O) + if err != nil { + return err + } + sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%? AND Column_name=%?", user, host, dbName, tbl.Meta().Name.O, col.Name.O) + + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String()) + if err != nil { + return err + } + } + return nil +} + +func privUpdateForRevoke(cur []string, priv mysql.PrivilegeType) ([]string, error) { + p, ok := mysql.Priv2SetStr[priv] + if !ok { + return nil, errors.Errorf("Unknown priv: %v", priv) + } + cur = deleteFromSet(cur, p) + return cur, nil +} + +func composeTablePrivUpdateForRevoke(ctx sessionctx.Context, sql *strings.Builder, priv mysql.PrivilegeType, name string, host string, db string, tbl string) error { + var newTablePriv, newColumnPriv []string + + if priv != mysql.AllPriv { + currTablePriv, currColumnPriv, err := getTablePriv(ctx, name, host, db, tbl) + if err != nil { + return err + } + + newTablePriv = setFromString(currTablePriv) + newTablePriv, err = privUpdateForRevoke(newTablePriv, priv) + if err != nil { + return err + } + + newColumnPriv = setFromString(currColumnPriv) + newColumnPriv, err = privUpdateForRevoke(newColumnPriv, priv) + if err != nil { + return err + } + } + + sqlexec.MustFormatSQL(sql, `Table_priv=%?, Column_priv=%?, Grantor=%?`, strings.Join(newTablePriv, ","), strings.Join(newColumnPriv, ","), ctx.GetSessionVars().User.String()) + return nil +} + +func composeColumnPrivUpdateForRevoke(ctx sessionctx.Context, sql *strings.Builder, priv mysql.PrivilegeType, name string, host string, db string, tbl string, col string) error { + var newColumnPriv []string + + if priv != mysql.AllPriv { + currColumnPriv, err := getColumnPriv(ctx, name, host, db, tbl, col) if err != nil { return err } - sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s' AND Column_name='%s';`, mysql.SystemDB, mysql.ColumnPrivTable, asgns, user, host, dbName, tbl.Meta().Name.O, col.Name.O) - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + + newColumnPriv = setFromString(currColumnPriv) + newColumnPriv, err = privUpdateForRevoke(newColumnPriv, priv) if err != nil { return err } } + + sqlexec.MustFormatSQL(sql, `Column_priv=%?`, strings.Join(newColumnPriv, ",")) return nil } diff --git a/executor/show.go b/executor/show.go index bb856ac7657e0..8d4095756b137 100644 --- a/executor/show.go +++ b/executor/show.go @@ -284,12 +284,17 @@ func (e *ShowExec) fetchShowBind() error { } func (e *ShowExec) fetchShowEngines() error { - sql := `SELECT * FROM information_schema.engines` - rows, _, err := e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := e.ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), `SELECT * FROM information_schema.engines`) if err != nil { return errors.Trace(err) } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) + if err != nil { + return errors.Trace(err) + } + for _, row := range rows { e.result.AppendRow(row) } @@ -410,16 +415,32 @@ func (e *ShowExec) fetchShowTableStatus() error { return ErrBadDB.GenWithStackByArgs(e.DBName) } - sql := fmt.Sprintf(`SELECT - table_name, engine, version, row_format, table_rows, - avg_row_length, data_length, max_data_length, index_length, - data_free, auto_increment, create_time, update_time, check_time, - table_collation, IFNULL(checksum,''), create_options, table_comment - FROM information_schema.tables - WHERE table_schema='%s' ORDER BY table_name`, e.DBName) + exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - rows, _, err := e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithSnapshot(sql) + stmt, err := exec.ParseWithParams(context.TODO(), `SELECT + table_name, engine, version, row_format, table_rows, + avg_row_length, data_length, max_data_length, index_length, + data_free, auto_increment, create_time, update_time, check_time, + table_collation, IFNULL(checksum,''), create_options, table_comment + FROM information_schema.tables + WHERE table_schema=%? ORDER BY table_name`, e.DBName.L) + if err != nil { + return errors.Trace(err) + } + var snapshot uint64 + txn, err := e.ctx.Txn(false) + if err != nil { + return errors.Trace(err) + } + if txn.Valid() { + snapshot = txn.StartTS() + } + if e.ctx.GetSessionVars().SnapshotTS != 0 { + snapshot = e.ctx.GetSessionVars().SnapshotTS + } + + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt, sqlexec.ExecOptionWithSnapshot(snapshot)) if err != nil { return errors.Trace(err) } @@ -1199,22 +1220,32 @@ func (e *ShowExec) fetchShowCreateUser() error { } } - sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User='%s' AND Host='%s';`, - mysql.SystemDB, mysql.UserTable, userName, hostName) - rows, _, err := e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := e.ctx.(sqlexec.RestrictedSQLExecutor) + + stmt, err := exec.ParseWithParams(context.TODO(), `SELECT * FROM %n.%n WHERE User=%? AND Host=%?`, mysql.SystemDB, mysql.UserTable, userName, hostName) + if err != nil { + return errors.Trace(err) + } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return errors.Trace(err) } + if len(rows) == 0 { + // FIXME: the error returned is not escaped safely return ErrCannotUser.GenWithStackByArgs("SHOW CREATE USER", fmt.Sprintf("'%s'@'%s'", e.User.Username, e.User.Hostname)) } - sql = fmt.Sprintf(`SELECT PRIV FROM %s.%s WHERE User='%s' AND Host='%s'`, - mysql.SystemDB, mysql.GlobalPrivTable, userName, hostName) - rows, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + + stmt, err = exec.ParseWithParams(context.TODO(), `SELECT Priv FROM %n.%n WHERE User=%? AND Host=%?`, mysql.SystemDB, mysql.GlobalPrivTable, userName, hostName) if err != nil { return errors.Trace(err) } + rows, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt) + if err != nil { + return errors.Trace(err) + } + require := "NONE" if len(rows) == 1 { privData := rows[0].GetString(0) @@ -1225,6 +1256,7 @@ func (e *ShowExec) fetchShowCreateUser() error { } require = privValue.RequireStr() } + // FIXME: the returned string is not escaped safely showStr := fmt.Sprintf("CREATE USER '%s'@'%s' IDENTIFIED WITH 'mysql_native_password' AS '%s' REQUIRE %s PASSWORD EXPIRE DEFAULT ACCOUNT UNLOCK", e.User.Username, e.User.Hostname, checker.GetEncodedPassword(e.User.Username, e.User.Hostname), require) e.appendRow([]interface{}{showStr}) diff --git a/executor/simple.go b/executor/simple.go index 2f1dca0cb982d..9a91191035cbd 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -80,7 +80,7 @@ func (e *baseExecutor) releaseSysSession(ctx sessionctx.Context) { } dom := domain.GetDomain(e.ctx) sysSessionPool := dom.SysSessionPool() - if _, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback"); err != nil { + if _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), "rollback"); err != nil { ctx.(pools.Resource).Close() return } @@ -151,23 +151,25 @@ func (e *SimpleExec) setDefaultRoleNone(s *ast.SetDefaultRoleStmt) error { } defer e.releaseSysSession(restrictedCtx) sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor) - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil { return err } + sql := new(strings.Builder) for _, u := range s.UserList { if u.Hostname == "" { u.Hostname = "%" } - sql := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", u.Username, u.Hostname) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, "DELETE IGNORE FROM mysql.default_roles WHERE USER=%? AND HOST=%?;", u.Username, u.Hostname) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { logutil.BgLogger().Error(fmt.Sprintf("Error occur when executing %s", sql)) - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } } - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return err } return nil @@ -199,42 +201,45 @@ func (e *SimpleExec) setDefaultRoleRegular(s *ast.SetDefaultRoleStmt) error { } defer e.releaseSysSession(restrictedCtx) sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor) - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil { return err } + sql := new(strings.Builder) for _, user := range s.UserList { if user.Hostname == "" { user.Hostname = "%" } - sql := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", user.Username, user.Hostname) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, "DELETE IGNORE FROM mysql.default_roles WHERE USER=%? AND HOST=%?;", user.Username, user.Hostname) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { logutil.BgLogger().Error(fmt.Sprintf("Error occur when executing %s", sql)) - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } for _, role := range s.RoleList { - sql := fmt.Sprintf("INSERT IGNORE INTO mysql.default_roles values('%s', '%s', '%s', '%s');", user.Hostname, user.Username, role.Hostname, role.Username) checker := privilege.GetPrivilegeManager(e.ctx) ok := checker.FindEdge(e.ctx, role, user) if ok { - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, "INSERT IGNORE INTO mysql.default_roles values(%?, %?, %?, %?);", user.Hostname, user.Username, role.Hostname, role.Username) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { logutil.BgLogger().Error(fmt.Sprintf("Error occur when executing %s", sql)) - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } } else { - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return ErrRoleNotGranted.GenWithStackByArgs(role.String(), user.String()) } } } - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return err } return nil @@ -256,31 +261,34 @@ func (e *SimpleExec) setDefaultRoleAll(s *ast.SetDefaultRoleStmt) error { } defer e.releaseSysSession(restrictedCtx) sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor) - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil { return err } + sql := new(strings.Builder) for _, user := range s.UserList { if user.Hostname == "" { user.Hostname = "%" } - sql := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", user.Username, user.Hostname) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, "DELETE IGNORE FROM mysql.default_roles WHERE USER=%? AND HOST=%?;", user.Username, user.Hostname) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { logutil.BgLogger().Error(fmt.Sprintf("Error occur when executing %s", sql)) - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } - sql = fmt.Sprintf("INSERT IGNORE INTO mysql.default_roles(HOST,USER,DEFAULT_ROLE_HOST,DEFAULT_ROLE_USER) "+ - "SELECT TO_HOST,TO_USER,FROM_HOST,FROM_USER FROM mysql.role_edges WHERE TO_HOST='%s' AND TO_USER='%s';", user.Hostname, user.Username) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, "INSERT IGNORE INTO mysql.default_roles(HOST,USER,DEFAULT_ROLE_HOST,DEFAULT_ROLE_USER) SELECT TO_HOST,TO_USER,FROM_HOST,FROM_USER FROM mysql.role_edges WHERE TO_HOST=%? AND TO_USER=%?;", user.Hostname, user.Username) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { + logutil.BgLogger().Error(fmt.Sprintf("Error occur when executing %s", sql)) + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } } - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return err } return nil @@ -288,29 +296,10 @@ func (e *SimpleExec) setDefaultRoleAll(s *ast.SetDefaultRoleStmt) error { func (e *SimpleExec) setDefaultRoleForCurrentUser(s *ast.SetDefaultRoleStmt) (err error) { checker := privilege.GetPrivilegeManager(e.ctx) - user, sql := s.UserList[0], "" + user := s.UserList[0] if user.Hostname == "" { user.Hostname = "%" } - switch s.SetRoleOpt { - case ast.SetRoleNone: - sql = fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", user.Username, user.Hostname) - case ast.SetRoleAll: - sql = fmt.Sprintf("INSERT IGNORE INTO mysql.default_roles(HOST,USER,DEFAULT_ROLE_HOST,DEFAULT_ROLE_USER) "+ - "SELECT TO_HOST,TO_USER,FROM_HOST,FROM_USER FROM mysql.role_edges WHERE TO_HOST='%s' AND TO_USER='%s';", user.Hostname, user.Username) - case ast.SetRoleRegular: - sql = "INSERT IGNORE INTO mysql.default_roles values" - for i, role := range s.RoleList { - ok := checker.FindEdge(e.ctx, role, user) - if !ok { - return ErrRoleNotGranted.GenWithStackByArgs(role.String(), user.String()) - } - sql += fmt.Sprintf("('%s', '%s', '%s', '%s')", user.Hostname, user.Username, role.Hostname, role.Username) - if i != len(s.RoleList)-1 { - sql += "," - } - } - } restrictedCtx, err := e.getSysSession() if err != nil { @@ -319,27 +308,48 @@ func (e *SimpleExec) setDefaultRoleForCurrentUser(s *ast.SetDefaultRoleStmt) (er defer e.releaseSysSession(restrictedCtx) sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor) - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil { return err } - deleteSQL := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", user.Username, user.Hostname) - if _, err := sqlExecutor.Execute(context.Background(), deleteSQL); err != nil { + sql := new(strings.Builder) + sqlexec.MustFormatSQL(sql, "DELETE IGNORE FROM mysql.default_roles WHERE USER=%? AND HOST=%?;", user.Username, user.Hostname) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { logutil.BgLogger().Error(fmt.Sprintf("Error occur when executing %s", sql)) - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + switch s.SetRoleOpt { + case ast.SetRoleNone: + sqlexec.MustFormatSQL(sql, "DELETE IGNORE FROM mysql.default_roles WHERE USER=%? AND HOST=%?;", user.Username, user.Hostname) + case ast.SetRoleAll: + sqlexec.MustFormatSQL(sql, "INSERT IGNORE INTO mysql.default_roles(HOST,USER,DEFAULT_ROLE_HOST,DEFAULT_ROLE_USER) SELECT TO_HOST,TO_USER,FROM_HOST,FROM_USER FROM mysql.role_edges WHERE TO_HOST=%? AND TO_USER=%?;", user.Hostname, user.Username) + case ast.SetRoleRegular: + sqlexec.MustFormatSQL(sql, "INSERT IGNORE INTO mysql.default_roles values") + for i, role := range s.RoleList { + if i > 0 { + sqlexec.MustFormatSQL(sql, ",") + } + ok := checker.FindEdge(e.ctx, role, user) + if !ok { + return ErrRoleNotGranted.GenWithStackByArgs(role.String(), user.String()) + } + sqlexec.MustFormatSQL(sql, "(%?, %?, %?, %?)", user.Hostname, user.Username, role.Hostname, role.Username) + } + } + + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { logutil.BgLogger().Error(fmt.Sprintf("Error occur when executing %s", sql)) - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return err } return nil @@ -595,16 +605,17 @@ func (e *SimpleExec) executeRevokeRole(s *ast.RevokeRoleStmt) error { sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor) // begin a transaction to insert role graph edges. - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil { return errors.Trace(err) } + sql := new(strings.Builder) for _, user := range s.Users { exists, err := userExists(e.ctx, user.Username, user.Hostname) if err != nil { return errors.Trace(err) } if !exists { - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); err != nil { return errors.Trace(err) } return ErrCannotUser.GenWithStackByArgs("REVOKE ROLE", user.String()) @@ -613,23 +624,26 @@ func (e *SimpleExec) executeRevokeRole(s *ast.RevokeRoleStmt) error { if role.Hostname == "" { role.Hostname = "%" } - sql := fmt.Sprintf(`DELETE IGNORE FROM %s.%s WHERE FROM_HOST='%s' and FROM_USER='%s' and TO_HOST='%s' and TO_USER='%s'`, mysql.SystemDB, mysql.RoleEdgeTable, role.Hostname, role.Username, user.Hostname, user.Username) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE IGNORE FROM %n.%n WHERE FROM_HOST=%? and FROM_USER=%? and TO_HOST=%? and TO_USER=%?`, mysql.SystemDB, mysql.RoleEdgeTable, role.Hostname, role.Username, user.Hostname, user.Username) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); err != nil { return errors.Trace(err) } return ErrCannotUser.GenWithStackByArgs("REVOKE ROLE", role.String()) } - sql = fmt.Sprintf(`DELETE IGNORE FROM %s.%s WHERE DEFAULT_ROLE_HOST='%s' and DEFAULT_ROLE_USER='%s' and HOST='%s' and USER='%s'`, mysql.SystemDB, mysql.DefaultRoleTable, role.Hostname, role.Username, user.Hostname, user.Username) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { + + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE IGNORE FROM %n.%n WHERE DEFAULT_ROLE_HOST=%? and DEFAULT_ROLE_USER=%? and HOST=%? and USER=%?`, mysql.SystemDB, mysql.DefaultRoleTable, role.Hostname, role.Username, user.Hostname, user.Username) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); err != nil { return errors.Trace(err) } return ErrCannotUser.GenWithStackByArgs("REVOKE ROLE", role.String()) } } } - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return err } domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) @@ -687,9 +701,18 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm return err } - users := make([]string, 0, len(s.Specs)) - privs := make([]string, 0, len(s.Specs)) + sql := new(strings.Builder) + if s.IsCreateRole { + sqlexec.MustFormatSQL(sql, `INSERT INTO %n.%n (Host, User, authentication_string, Account_locked) VALUES `, mysql.SystemDB, mysql.UserTable) + } else { + sqlexec.MustFormatSQL(sql, `INSERT INTO %n.%n (Host, User, authentication_string) VALUES `, mysql.SystemDB, mysql.UserTable) + } + + users := make([]*auth.UserIdentity, 0, len(s.Specs)) for _, spec := range s.Specs { + if len(users) > 0 { + sqlexec.MustFormatSQL(sql, ",") + } exists, err1 := userExists(e.ctx, spec.User.Username, spec.User.Hostname) if err1 != nil { return err1 @@ -710,26 +733,17 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm if !ok { return errors.Trace(ErrPasswordFormat) } - user := fmt.Sprintf(`('%s', '%s', '%s')`, spec.User.Hostname, spec.User.Username, pwd) if s.IsCreateRole { - user = fmt.Sprintf(`('%s', '%s', '%s', 'Y')`, spec.User.Hostname, spec.User.Username, pwd) - } - users = append(users, user) - - if len(privData) != 0 { - priv := fmt.Sprintf(`('%s', '%s', '%s')`, spec.User.Hostname, spec.User.Username, hack.String(privData)) - privs = append(privs, priv) + sqlexec.MustFormatSQL(sql, `(%?, %?, %?, %?)`, spec.User.Hostname, spec.User.Username, pwd, "Y") + } else { + sqlexec.MustFormatSQL(sql, `(%?, %?, %?)`, spec.User.Hostname, spec.User.Username, pwd) } + users = append(users, spec.User) } if len(users) == 0 { return nil } - sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, authentication_string) VALUES %s;`, mysql.SystemDB, mysql.UserTable, strings.Join(users, ", ")) - if s.IsCreateRole { - sql = fmt.Sprintf(`INSERT INTO %s.%s (Host, User, authentication_string, Account_locked) VALUES %s;`, mysql.SystemDB, mysql.UserTable, strings.Join(users, ", ")) - } - restrictedCtx, err := e.getSysSession() if err != nil { return err @@ -737,27 +751,34 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm defer e.releaseSysSession(restrictedCtx) sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor) - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil { return errors.Trace(err) } - _, err = sqlExecutor.Execute(context.Background(), sql) + _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()) if err != nil { - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } - if len(privs) != 0 { - sql = fmt.Sprintf("INSERT IGNORE INTO %s.%s (Host, User, Priv) VALUES %s", mysql.SystemDB, mysql.GlobalPrivTable, strings.Join(privs, ", ")) - _, err = sqlExecutor.Execute(context.Background(), sql) + if len(privData) != 0 { + sql.Reset() + sqlexec.MustFormatSQL(sql, "INSERT IGNORE INTO %n.%n (Host, User, Priv) VALUES ", mysql.SystemDB, mysql.GlobalPrivTable) + for i, user := range users { + if i > 0 { + sqlexec.MustFormatSQL(sql, ",") + } + sqlexec.MustFormatSQL(sql, `(%?, %?, %?)`, user.Hostname, user.Username, string(hack.String(privData))) + } + _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()) if err != nil { - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } } - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return errors.Trace(err) } domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) @@ -806,17 +827,22 @@ func (e *SimpleExec) executeAlterUser(s *ast.AlterUserStmt) error { if !ok { return errors.Trace(ErrPasswordFormat) } - sql := fmt.Sprintf(`UPDATE %s.%s SET authentication_string = '%s' WHERE Host = '%s' and User = '%s';`, - mysql.SystemDB, mysql.UserTable, pwd, spec.User.Hostname, spec.User.Username) - _, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := e.ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), `UPDATE %n.%n SET authentication_string=%? WHERE Host=%? and User=%?;`, mysql.SystemDB, mysql.UserTable, pwd, spec.User.Hostname, spec.User.Username) + if err != nil { + return err + } + _, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { failedUsers = append(failedUsers, spec.User.String()) } if len(privData) > 0 { - sql = fmt.Sprintf("INSERT INTO %s.%s (Host, User, Priv) VALUES ('%s','%s','%s') ON DUPLICATE KEY UPDATE Priv = values(Priv)", - mysql.SystemDB, mysql.GlobalPrivTable, spec.User.Hostname, spec.User.Username, hack.String(privData)) - _, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + stmt, err = exec.ParseWithParams(context.TODO(), "INSERT INTO %n.%n (Host, User, Priv) VALUES (%?,%?,%?) ON DUPLICATE KEY UPDATE Priv = values(Priv)", mysql.SystemDB, mysql.GlobalPrivTable, spec.User.Hostname, spec.User.Username, string(hack.String(privData))) + if err != nil { + return err + } + _, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { failedUsers = append(failedUsers, spec.User.String()) } @@ -880,23 +906,25 @@ func (e *SimpleExec) executeGrantRole(s *ast.GrantRoleStmt) error { sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor) // begin a transaction to insert role graph edges. - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil { return err } + sql := new(strings.Builder) for _, user := range s.Users { for _, role := range s.Roles { - sql := fmt.Sprintf(`INSERT IGNORE INTO %s.%s (FROM_HOST, FROM_USER, TO_HOST, TO_USER) VALUES ('%s','%s','%s','%s')`, mysql.SystemDB, mysql.RoleEdgeTable, role.Hostname, role.Username, user.Hostname, user.Username) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `INSERT IGNORE INTO %n.%n (FROM_HOST, FROM_USER, TO_HOST, TO_USER) VALUES (%?,%?,%?,%?)`, mysql.SystemDB, mysql.RoleEdgeTable, role.Hostname, role.Username, user.Hostname, user.Username) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { logutil.BgLogger().Error(fmt.Sprintf("Error occur when executing %s", sql)) - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); err != nil { return err } return ErrCannotUser.GenWithStackByArgs("GRANT ROLE", user.String()) } } } - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return err } domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) @@ -933,10 +961,11 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { } sqlExecutor := sysSession.(sqlexec.SQLExecutor) - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil { return err } + sql := new(strings.Builder) for _, user := range s.UserList { exists, err := userExists(e.ctx, user.Username, user.Hostname) if err != nil { @@ -952,58 +981,66 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { } // begin a transaction to delete a user. - sql := fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = '%s' and User = '%s';`, mysql.SystemDB, mysql.UserTable, user.Hostname, user.Username) - if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE FROM %n.%n WHERE Host = %? and User = %?;`, mysql.SystemDB, mysql.UserTable, user.Hostname, user.Username) + if _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { failedUsers = append(failedUsers, user.String()) break } // delete privileges from mysql.global_priv - sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = '%s' and User = '%s';`, mysql.SystemDB, mysql.GlobalPrivTable, user.Hostname, user.Username) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE FROM %n.%n WHERE Host = %? and User = %?;`, mysql.SystemDB, mysql.GlobalPrivTable, user.Hostname, user.Username) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { failedUsers = append(failedUsers, user.String()) - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); err != nil { return err } continue } // delete privileges from mysql.db - sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = '%s' and User = '%s';`, mysql.SystemDB, mysql.DBTable, user.Hostname, user.Username) - if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE FROM %n.%n WHERE Host = %? and User = %?;`, mysql.SystemDB, mysql.DBTable, user.Hostname, user.Username) + if _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { failedUsers = append(failedUsers, user.String()) break } // delete privileges from mysql.tables_priv - sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = '%s' and User = '%s';`, mysql.SystemDB, mysql.TablePrivTable, user.Hostname, user.Username) - if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE FROM %n.%n WHERE Host = %? and User = %?;`, mysql.SystemDB, mysql.TablePrivTable, user.Hostname, user.Username) + if _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { failedUsers = append(failedUsers, user.String()) break } // delete relationship from mysql.role_edges - sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE TO_HOST = '%s' and TO_USER = '%s';`, mysql.SystemDB, mysql.RoleEdgeTable, user.Hostname, user.Username) - if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE FROM %n.%n WHERE TO_HOST = %? and TO_USER = %?;`, mysql.SystemDB, mysql.RoleEdgeTable, user.Hostname, user.Username) + if _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { failedUsers = append(failedUsers, user.String()) break } - sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE FROM_HOST = '%s' and FROM_USER = '%s';`, mysql.SystemDB, mysql.RoleEdgeTable, user.Hostname, user.Username) - if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE FROM %n.%n WHERE FROM_HOST = %? and FROM_USER = %?;`, mysql.SystemDB, mysql.RoleEdgeTable, user.Hostname, user.Username) + if _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { failedUsers = append(failedUsers, user.String()) break } // delete relationship from mysql.default_roles - sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE DEFAULT_ROLE_HOST = '%s' and DEFAULT_ROLE_USER = '%s';`, mysql.SystemDB, mysql.DefaultRoleTable, user.Hostname, user.Username) - if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE FROM %n.%n WHERE DEFAULT_ROLE_HOST = %? and DEFAULT_ROLE_USER = %?;`, mysql.SystemDB, mysql.DefaultRoleTable, user.Hostname, user.Username) + if _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { failedUsers = append(failedUsers, user.String()) break } - sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE HOST = '%s' and USER = '%s';`, mysql.SystemDB, mysql.DefaultRoleTable, user.Hostname, user.Username) - if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE FROM %n.%n WHERE HOST = %? and USER = %?;`, mysql.SystemDB, mysql.DefaultRoleTable, user.Hostname, user.Username) + if _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { failedUsers = append(failedUsers, user.String()) break } @@ -1011,11 +1048,11 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { } if len(failedUsers) == 0 { - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return err } } else { - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); err != nil { return err } if s.IsDropRole { @@ -1028,8 +1065,12 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { } func userExists(ctx sessionctx.Context, name string, host string) (bool, error) { - sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User='%s' AND Host='%s';`, mysql.SystemDB, mysql.UserTable, name, host) - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), `SELECT * FROM %n.%n WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, name, host) + if err != nil { + return false, err + } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return false, err } @@ -1062,8 +1103,12 @@ func (e *SimpleExec) executeSetPwd(s *ast.SetPwdStmt) error { } // update mysql.user - sql := fmt.Sprintf(`UPDATE %s.%s SET authentication_string='%s' WHERE User='%s' AND Host='%s';`, mysql.SystemDB, mysql.UserTable, auth.EncodePassword(s.Password), u, h) - _, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := e.ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), `UPDATE %n.%n SET authentication_string=%? WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, auth.EncodePassword(s.Password), u, h) + if err != nil { + return err + } + _, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt) domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) return err } diff --git a/executor/trace.go b/executor/trace.go index bf9150f357081..8eb6592c4822f 100644 --- a/executor/trace.go +++ b/executor/trace.go @@ -132,24 +132,21 @@ func (e *TraceExec) nextRowJSON(ctx context.Context, se sqlexec.SQLExecutor, req } func (e *TraceExec) executeChild(ctx context.Context, se sqlexec.SQLExecutor) { - recordSets, err := se.Execute(ctx, e.stmtNode.Text()) - if len(recordSets) == 0 { - if err != nil { - var errCode uint16 - if te, ok := err.(*terror.Error); ok { - errCode = terror.ToSQLError(te).Code - } - logutil.Eventf(ctx, "execute with error(%d): %s", errCode, err.Error()) - } else { - logutil.Eventf(ctx, "execute done, modify row: %d", e.ctx.GetSessionVars().StmtCtx.AffectedRows()) + rs, err := se.ExecuteStmt(ctx, e.stmtNode) + if err != nil { + var errCode uint16 + if te, ok := err.(*terror.Error); ok { + errCode = terror.ToSQLError(te).Code } + logutil.Eventf(ctx, "execute with error(%d): %s", errCode, err.Error()) } - for _, rs := range recordSets { + if rs != nil { drainRecordSet(ctx, e.ctx, rs) if err = rs.Close(); err != nil { logutil.Logger(ctx).Error("run trace close result with error", zap.Error(err)) } } + logutil.Eventf(ctx, "execute done, modify row: %d", e.ctx.GetSessionVars().StmtCtx.AffectedRows()) } func drainRecordSet(ctx context.Context, sctx sessionctx.Context, rs sqlexec.RecordSet) { diff --git a/executor/utils.go b/executor/utils.go new file mode 100644 index 0000000000000..fbc9ab4dcff30 --- /dev/null +++ b/executor/utils.go @@ -0,0 +1,46 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import "strings" + +func setFromString(value string) []string { + if len(value) == 0 { + return nil + } + return strings.Split(value, ",") +} + +// addToSet add a value to the set, e.g: +// addToSet("Select,Insert,Update", "Update") returns "Select,Insert,Update". +func addToSet(set []string, value string) []string { + for _, v := range set { + if v == value { + return set + } + } + return append(set, value) +} + +// deleteFromSet delete the value from the set, e.g: +// deleteFromSet("Select,Insert,Update", "Update") returns "Select,Insert". +func deleteFromSet(set []string, value string) []string { + for i, v := range set { + if v == value { + copy(set[i:], set[i+1:]) + return set[:len(set)-1] + } + } + return set +} diff --git a/telemetry/data_cluster_hardware.go b/telemetry/data_cluster_hardware.go index 318e1ba63f48f..eb380e9503ac7 100644 --- a/telemetry/data_cluster_hardware.go +++ b/telemetry/data_cluster_hardware.go @@ -14,6 +14,7 @@ package telemetry import ( + "context" "regexp" "sort" "strings" @@ -66,7 +67,12 @@ func normalizeFieldName(name string) string { } func getClusterHardware(ctx sessionctx.Context) ([]*clusterHardwareItem, error) { - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(`SELECT TYPE, INSTANCE, DEVICE_TYPE, DEVICE_NAME, NAME, VALUE FROM information_schema.cluster_hardware`) + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), `SELECT TYPE, INSTANCE, DEVICE_TYPE, DEVICE_NAME, NAME, VALUE FROM information_schema.cluster_hardware`) + if err != nil { + return nil, errors.Trace(err) + } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return nil, errors.Trace(err) } diff --git a/telemetry/data_cluster_info.go b/telemetry/data_cluster_info.go index fdb1be6bafc27..46b8cfb8f7b47 100644 --- a/telemetry/data_cluster_info.go +++ b/telemetry/data_cluster_info.go @@ -14,6 +14,8 @@ package telemetry import ( + "context" + "github.com/pingcap/errors" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/util/sqlexec" @@ -33,7 +35,12 @@ type clusterInfoItem struct { func getClusterInfo(ctx sessionctx.Context) ([]*clusterInfoItem, error) { // Explicitly list all field names instead of using `*` to avoid potential leaking sensitive info when adding new fields in future. - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(`SELECT TYPE, INSTANCE, STATUS_ADDRESS, VERSION, GIT_HASH, START_TIME, UPTIME FROM information_schema.cluster_info`) + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), `SELECT TYPE, INSTANCE, STATUS_ADDRESS, VERSION, GIT_HASH, START_TIME, UPTIME FROM information_schema.cluster_info`) + if err != nil { + return nil, errors.Trace(err) + } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return nil, errors.Trace(err) } diff --git a/util/mock/context.go b/util/mock/context.go index f049a05beb842..e2461c7bc8446 100644 --- a/util/mock/context.go +++ b/util/mock/context.go @@ -20,6 +20,7 @@ import ( "time" "github.com/pingcap/errors" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/owner" @@ -61,6 +62,11 @@ func (c *Context) Execute(ctx context.Context, sql string) ([]sqlexec.RecordSet, return nil, errors.Errorf("Not Support.") } +// ExecuteStmt implements sqlexec.SQLExecutor ExecuteStmt interface. +func (c *Context) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlexec.RecordSet, error) { + return nil, errors.Errorf("Not Supported.") +} + // ExecuteInternal implements sqlexec.SQLExecutor ExecuteInternal interface. func (c *Context) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (sqlexec.RecordSet, error) { return nil, errors.Errorf("Not Supported.") diff --git a/util/sqlexec/restricted_sql_executor.go b/util/sqlexec/restricted_sql_executor.go index cc23bda405b9f..597873d050151 100644 --- a/util/sqlexec/restricted_sql_executor.go +++ b/util/sqlexec/restricted_sql_executor.go @@ -89,6 +89,7 @@ type SQLExecutor interface { Execute(ctx context.Context, sql string) ([]RecordSet, error) // ExecuteInternal means execute sql as the internal sql. ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (RecordSet, error) + ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (RecordSet, error) } // SQLParser is an interface provides parsing sql statement.