Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: adapt new api for the executor package (#22644) #23156

Merged
merged 3 commits into from
Mar 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions executor/brie.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ func (gs *tidbGlueSession) CreateSession(store kv.Storage) (glue.Session, error)

// Execute implements glue.Session
func (gs *tidbGlueSession) Execute(ctx context.Context, sql string) error {
// FIXME: br relies on a deprecated API, it may be unsafe
_, err := gs.se.(sqlexec.SQLExecutor).Execute(ctx, sql)
return err
}
Expand Down
8 changes: 6 additions & 2 deletions executor/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,12 @@ func (e *DDLExec) dropTableObject(objects []*ast.TableName, obt objectType, ifEx
zap.String("database", fullti.Schema.O),
zap.String("table", fullti.Name.O),
)
sql := fmt.Sprintf("admin check table `%s`.`%s`", fullti.Schema.O, fullti.Name.O)
_, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql)
exec := e.ctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.TODO(), "admin check table %n.%n", fullti.Schema.O, fullti.Name.O)
if err != nil {
return err
}
_, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt)
if err != nil {
return err
}
Expand Down
342 changes: 147 additions & 195 deletions executor/grant.go

Large diffs are not rendered by default.

16 changes: 14 additions & 2 deletions executor/infoschema_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,16 @@ func (e *memtableRetriever) retrieve(ctx context.Context, sctx sessionctx.Contex
}

func getRowCountAllTable(ctx sessionctx.Context) (map[int64]uint64, error) {
rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL("select table_id, count from mysql.stats_meta")
exec := ctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.TODO(), "select table_id, count from mysql.stats_meta")
if err != nil {
return nil, err
}
rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt)
if err != nil {
return nil, err
}

rowCountMap := make(map[int64]uint64, len(rows))
for _, row := range rows {
tableID := row.GetInt64(0)
Expand All @@ -173,10 +179,16 @@ type tableHistID struct {
}

func getColLengthAllTables(ctx sessionctx.Context) (map[tableHistID]uint64, error) {
rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL("select table_id, hist_id, tot_col_size from mysql.stats_histograms where is_index = 0")
exec := ctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.TODO(), "select table_id, hist_id, tot_col_size from mysql.stats_histograms where is_index = 0")
if err != nil {
return nil, err
}
rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt)
if err != nil {
return nil, err
}

colLengthMap := make(map[tableHistID]uint64, len(rows))
for _, row := range rows {
tableID := row.GetInt64(0)
Expand Down
8 changes: 7 additions & 1 deletion executor/inspection_profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,13 @@ func (n *metricNode) getLabelValue(label string) *metricValue {
}

func (n *metricNode) queryRowsByLabel(pb *profileBuilder, query string, handleRowFn func(label string, v float64)) error {
rows, _, err := pb.sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(context.Background(), query)
exec := pb.sctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.TODO(), query)
if err != nil {
return err
}

rows, _, err := pb.sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.TODO(), stmt)
if err != nil {
return err
}
Expand Down
154 changes: 113 additions & 41 deletions executor/inspection_result.go

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion executor/inspection_summary.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,12 @@ func (e *inspectionSummaryRetriever) retrieve(ctx context.Context, sctx sessionc
sql = fmt.Sprintf("select avg(value),min(value),max(value) from `%s`.`%s` %s",
util.MetricSchemaName.L, name, cond)
}
rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql)
exec := sctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(ctx, sql)
if err != nil {
return nil, errors.Errorf("execute '%s' failed: %v", sql, err)
}
rows, _, err := exec.ExecRestrictedStmt(ctx, stmt)
if err != nil {
return nil, errors.Errorf("execute '%s' failed: %v", sql, err)
}
Expand Down
16 changes: 13 additions & 3 deletions executor/metrics_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ type MetricsSummaryRetriever struct {
retrieved bool
}

func (e *MetricsSummaryRetriever) retrieve(_ context.Context, sctx sessionctx.Context) ([][]types.Datum, error) {
func (e *MetricsSummaryRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) {
if e.retrieved || e.extractor.SkipRequest {
return nil, nil
}
Expand Down Expand Up @@ -229,7 +229,12 @@ func (e *MetricsSummaryRetriever) retrieve(_ context.Context, sctx sessionctx.Co
name, util.MetricSchemaName.L, condition)
}

rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql)
exec := sctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(ctx, sql)
if err != nil {
return nil, errors.Errorf("execute '%s' failed: %v", sql, err)
}
rows, _, err := exec.ExecRestrictedStmt(ctx, stmt)
if err != nil {
return nil, errors.Errorf("execute '%s' failed: %v", sql, err)
}
Expand Down Expand Up @@ -306,7 +311,12 @@ func (e *MetricsSummaryByLabelRetriever) retrieve(ctx context.Context, sctx sess
sql = fmt.Sprintf("select sum(value),avg(value),min(value),max(value) from `%s`.`%s` %s",
util.MetricSchemaName.L, name, cond)
}
rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql)
exec := sctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(ctx, sql)
if err != nil {
return nil, errors.Errorf("execute '%s' failed: %v", sql, err)
}
rows, _, err := exec.ExecRestrictedStmt(ctx, stmt)
if err != nil {
return nil, errors.Errorf("execute '%s' failed: %v", sql, err)
}
Expand Down
8 changes: 6 additions & 2 deletions executor/opt_rule_blacklist.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@ func (e *ReloadOptRuleBlacklistExec) Next(ctx context.Context, _ *chunk.Chunk) e

// LoadOptRuleBlacklist loads the latest data from table mysql.opt_rule_blacklist.
func LoadOptRuleBlacklist(ctx sessionctx.Context) (err error) {
sql := "select HIGH_PRIORITY name from mysql.opt_rule_blacklist"
rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql)
exec := ctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.TODO(), "select HIGH_PRIORITY name from mysql.opt_rule_blacklist")
if err != nil {
return err
}
rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt)
if err != nil {
return err
}
Expand Down
1 change: 1 addition & 0 deletions executor/prepared.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ func (e *PrepareExec) Next(ctx context.Context, req *chunk.Chunk) error {
err error
)
if sqlParser, ok := e.ctx.(sqlexec.SQLParser); ok {
// FIXME: ok... yet another parse API, may need some api interface clean.
stmts, err = sqlParser.ParseSQL(e.sqlText, charset, collation)
} else {
p := parser.New()
Expand Down
8 changes: 6 additions & 2 deletions executor/reload_expr_pushdown_blacklist.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,12 @@ func (e *ReloadExprPushdownBlacklistExec) Next(ctx context.Context, _ *chunk.Chu

// LoadExprPushdownBlacklist loads the latest data from table mysql.expr_pushdown_blacklist.
func LoadExprPushdownBlacklist(ctx sessionctx.Context) (err error) {
sql := "select HIGH_PRIORITY name, store_type from mysql.expr_pushdown_blacklist"
rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql)
exec := ctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.TODO(), "select HIGH_PRIORITY name, store_type from mysql.expr_pushdown_blacklist")
if err != nil {
return err
}
rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt)
if err != nil {
return err
}
Expand Down
103 changes: 87 additions & 16 deletions executor/revoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ package executor

import (
"context"
"fmt"
"strings"

"github.com/pingcap/errors"
"github.com/pingcap/parser/ast"
Expand Down Expand Up @@ -73,15 +73,15 @@ func (e *RevokeExec) Next(ctx context.Context, req *chunk.Chunk) error {
}
defer func() {
if !isCommit {
_, err := internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback")
_, err := internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), "rollback")
if err != nil {
logutil.BgLogger().Error("rollback error occur at grant privilege", zap.Error(err))
}
}
e.releaseSysSession(internalSession)
}()

_, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "begin")
_, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), "begin")
if err != nil {
return err
}
Expand All @@ -103,7 +103,7 @@ func (e *RevokeExec) Next(ctx context.Context, req *chunk.Chunk) error {
}
}

_, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "commit")
_, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), "commit")
if err != nil {
return err
}
Expand Down Expand Up @@ -166,12 +166,15 @@ func (e *RevokeExec) revokePriv(internalSession sessionctx.Context, priv *ast.Pr
}

func (e *RevokeExec) revokeGlobalPriv(internalSession sessionctx.Context, priv *ast.PrivElem, user, host string) error {
asgns, err := composeGlobalPrivUpdate(priv.Priv, "N")
sql := new(strings.Builder)
sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.UserTable)
err := composeGlobalPrivUpdate(sql, priv.Priv, "N")
if err != nil {
return err
}
sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s'`, mysql.SystemDB, mysql.UserTable, asgns, user, host)
_, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%?", user, host)

_, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String())
return err
}

Expand All @@ -180,12 +183,16 @@ func (e *RevokeExec) revokeDBPriv(internalSession sessionctx.Context, priv *ast.
if len(dbName) == 0 {
dbName = e.ctx.GetSessionVars().CurrentDB
}
asgns, err := composeDBPrivUpdate(priv.Priv, "N")

sql := new(strings.Builder)
sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.DBTable)
err := composeDBPrivUpdate(sql, priv.Priv, "N")
if err != nil {
return err
}
sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s';`, mysql.SystemDB, mysql.DBTable, asgns, userName, host, dbName)
_, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%? AND DB=%?", userName, host, dbName)

_, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String())
return err
}

Expand All @@ -194,12 +201,16 @@ func (e *RevokeExec) revokeTablePriv(internalSession sessionctx.Context, priv *a
if err != nil {
return err
}
asgns, err := composeTablePrivUpdateForRevoke(internalSession, priv.Priv, user, host, dbName, tbl.Meta().Name.O)

sql := new(strings.Builder)
sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.TablePrivTable)
err = composeTablePrivUpdateForRevoke(internalSession, sql, priv.Priv, user, host, dbName, tbl.Meta().Name.O)
if err != nil {
return err
}
sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s';`, mysql.SystemDB, mysql.TablePrivTable, asgns, user, host, dbName, tbl.Meta().Name.O)
_, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%?", user, host, dbName, tbl.Meta().Name.O)

_, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String())
return err
}

Expand All @@ -208,20 +219,80 @@ func (e *RevokeExec) revokeColumnPriv(internalSession sessionctx.Context, priv *
if err != nil {
return err
}
sql := new(strings.Builder)
for _, c := range priv.Cols {
col := table.FindCol(tbl.Cols(), c.Name.L)
if col == nil {
return errors.Errorf("Unknown column: %s", c)
}
asgns, err := composeColumnPrivUpdateForRevoke(internalSession, priv.Priv, user, host, dbName, tbl.Meta().Name.O, col.Name.O)

sql.Reset()
sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.ColumnPrivTable)
err = composeColumnPrivUpdateForRevoke(internalSession, sql, priv.Priv, user, host, dbName, tbl.Meta().Name.O, col.Name.O)
if err != nil {
return err
}
sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%? AND Column_name=%?", user, host, dbName, tbl.Meta().Name.O, col.Name.O)

_, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String())
if err != nil {
return err
}
}
return nil
}

func privUpdateForRevoke(cur []string, priv mysql.PrivilegeType) ([]string, error) {
p, ok := mysql.Priv2SetStr[priv]
if !ok {
return nil, errors.Errorf("Unknown priv: %v", priv)
}
cur = deleteFromSet(cur, p)
return cur, nil
}

func composeTablePrivUpdateForRevoke(ctx sessionctx.Context, sql *strings.Builder, priv mysql.PrivilegeType, name string, host string, db string, tbl string) error {
var newTablePriv, newColumnPriv []string

if priv != mysql.AllPriv {
currTablePriv, currColumnPriv, err := getTablePriv(ctx, name, host, db, tbl)
if err != nil {
return err
}

newTablePriv = setFromString(currTablePriv)
newTablePriv, err = privUpdateForRevoke(newTablePriv, priv)
if err != nil {
return err
}

newColumnPriv = setFromString(currColumnPriv)
newColumnPriv, err = privUpdateForRevoke(newColumnPriv, priv)
if err != nil {
return err
}
}

sqlexec.MustFormatSQL(sql, `Table_priv=%?, Column_priv=%?, Grantor=%?`, strings.Join(newTablePriv, ","), strings.Join(newColumnPriv, ","), ctx.GetSessionVars().User.String())
return nil
}

func composeColumnPrivUpdateForRevoke(ctx sessionctx.Context, sql *strings.Builder, priv mysql.PrivilegeType, name string, host string, db string, tbl string, col string) error {
var newColumnPriv []string

if priv != mysql.AllPriv {
currColumnPriv, err := getColumnPriv(ctx, name, host, db, tbl, col)
if err != nil {
return err
}
sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s' AND Column_name='%s';`, mysql.SystemDB, mysql.ColumnPrivTable, asgns, user, host, dbName, tbl.Meta().Name.O, col.Name.O)
_, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql)

newColumnPriv = setFromString(currColumnPriv)
newColumnPriv, err = privUpdateForRevoke(newColumnPriv, priv)
if err != nil {
return err
}
}

sqlexec.MustFormatSQL(sql, `Column_priv=%?`, strings.Join(newColumnPriv, ","))
return nil
}
Loading