Skip to content

Commit

Permalink
*: rewrite origin SQL with default DB for SQL bindings (#21275)
Browse files Browse the repository at this point in the history
  • Loading branch information
rebelice authored Jan 6, 2021
1 parent 8ef1031 commit 51794e9
Show file tree
Hide file tree
Showing 11 changed files with 419 additions and 152 deletions.
312 changes: 199 additions & 113 deletions bindinfo/bind_test.go

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion bindinfo/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func (br *BindRecord) shallowCopy() *BindRecord {
}

func (br *BindRecord) isSame(other *BindRecord) bool {
return br.OriginalSQL == other.OriginalSQL && br.Db == other.Db
return br.OriginalSQL == other.OriginalSQL
}

var statusIndex = map[string]int{
Expand Down
23 changes: 11 additions & 12 deletions bindinfo/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ func (h *BindHandle) DropBindRecord(originalSQL, db string, binding *Binding) (e
func (h *BindHandle) lockBindInfoTable() error {
// h.sctx already locked.
exec, _ := h.sctx.Context.(sqlexec.SQLExecutor)
_, err := exec.ExecuteInternal(context.TODO(), h.lockBindInfoSQL())
_, err := exec.ExecuteInternal(context.TODO(), h.LockBindInfoSQL())
return err
}

Expand Down Expand Up @@ -539,7 +539,7 @@ func (c cache) removeDeletedBindRecord(hash string, meta *BindRecord) {
func (c cache) setBindRecord(hash string, meta *BindRecord) {
metas := c[hash]
for i := range metas {
if metas[i].Db == meta.Db && metas[i].OriginalSQL == meta.OriginalSQL {
if metas[i].OriginalSQL == meta.OriginalSQL {
metas[i] = meta
return
}
Expand Down Expand Up @@ -568,7 +568,7 @@ func copyBindRecordUpdateMap(oldMap map[string]*bindRecordUpdate) map[string]*bi
func (c cache) getBindRecord(hash, normdOrigSQL, db string) *BindRecord {
bindRecords := c[hash]
for _, bindRecord := range bindRecords {
if bindRecord.OriginalSQL == normdOrigSQL && bindRecord.Db == db {
if bindRecord.OriginalSQL == normdOrigSQL {
return bindRecord
}
}
Expand All @@ -577,9 +577,8 @@ func (c cache) getBindRecord(hash, normdOrigSQL, db string) *BindRecord {

func (h *BindHandle) deleteBindInfoSQL(normdOrigSQL, db, bindSQL string) string {
sql := fmt.Sprintf(
`DELETE FROM mysql.bind_info WHERE original_sql=%s AND LOWER(default_db)=%s`,
`DELETE FROM mysql.bind_info WHERE original_sql=%s`,
expression.Quote(normdOrigSQL),
expression.Quote(db),
)
if bindSQL == "" {
return sql
Expand All @@ -601,20 +600,19 @@ func (h *BindHandle) insertBindInfoSQL(orignalSQL string, db string, info Bindin
)
}

// lockBindInfoSQL simulates LOCK TABLE by updating a same row in each pessimistic transaction.
func (h *BindHandle) lockBindInfoSQL() string {
// LockBindInfoSQL simulates LOCK TABLE by updating a same row in each pessimistic transaction.
func (h *BindHandle) LockBindInfoSQL() string {
return fmt.Sprintf("UPDATE mysql.bind_info SET source=%s WHERE original_sql=%s",
expression.Quote(Builtin),
expression.Quote(BuiltinPseudoSQL4BindLock))
}

func (h *BindHandle) logicalDeleteBindInfoSQL(originalSQL, db string, updateTs types.Time, bindingSQL string) string {
updateTsStr := updateTs.String()
sql := fmt.Sprintf(`UPDATE mysql.bind_info SET status=%s,update_time=%s WHERE original_sql=%s and LOWER(default_db)=%s and update_time<%s`,
sql := fmt.Sprintf(`UPDATE mysql.bind_info SET status=%s,update_time=%s WHERE original_sql=%s and update_time<%s`,
expression.Quote(deleted),
expression.Quote(updateTsStr),
expression.Quote(originalSQL),
expression.Quote(db),
expression.Quote(updateTsStr))
if bindingSQL == "" {
return sql
Expand All @@ -635,12 +633,12 @@ func (h *BindHandle) CaptureBaselines() {
if insertStmt, ok := stmt.(*ast.InsertStmt); ok && insertStmt.Select == nil {
continue
}
normalizedSQL, digest := parser.NormalizeDigest(bindableStmt.Query)
dbName := utilparser.GetDefaultDB(stmt, bindableStmt.Schema)
normalizedSQL, digest := parser.NormalizeDigest(utilparser.RestoreWithDefaultDB(stmt, dbName))
if r := h.GetBindRecord(digest, normalizedSQL, dbName); r != nil && r.HasUsingBinding() {
continue
}
bindSQL := GenerateBindSQL(context.TODO(), stmt, bindableStmt.PlanHint, true)
bindSQL := GenerateBindSQL(context.TODO(), stmt, bindableStmt.PlanHint, true, dbName)
if bindSQL == "" {
continue
}
Expand Down Expand Up @@ -680,7 +678,7 @@ func getHintsForSQL(sctx sessionctx.Context, sql string) (string, error) {
}

// GenerateBindSQL generates binding sqls from stmt node and plan hints.
func GenerateBindSQL(ctx context.Context, stmtNode ast.StmtNode, planHint string, captured bool) string {
func GenerateBindSQL(ctx context.Context, stmtNode ast.StmtNode, planHint string, captured bool, defaultDB string) string {
// If would be nil for very simple cases such as point get, we do not need to evolve for them.
if planHint == "" {
return ""
Expand All @@ -699,6 +697,7 @@ func GenerateBindSQL(ctx context.Context, stmtNode ast.StmtNode, planHint string
hint.BindHint(stmtNode, &hint.HintsSet{})
var sb strings.Builder
restoreCtx := format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)
restoreCtx.DefaultDB = defaultDB
err := stmtNode.Restore(restoreCtx)
if err != nil {
logutil.Logger(ctx).Debug("[sql-bind] restore SQL failed when generating bind SQL", zap.Error(err))
Expand Down
2 changes: 1 addition & 1 deletion bindinfo/session_handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (h *SessionHandle) GetBindRecord(normdOrigSQL, db string) *BindRecord {
hash := parser.DigestNormalized(normdOrigSQL)
bindRecords := h.ch[hash]
for _, bindRecord := range bindRecords {
if bindRecord.OriginalSQL == normdOrigSQL && bindRecord.Db == db {
if bindRecord.OriginalSQL == normdOrigSQL {
return bindRecord
}
}
Expand Down
37 changes: 33 additions & 4 deletions planner/core/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -751,23 +751,52 @@ func (b *PlanBuilder) buildSet(ctx context.Context, v *ast.SetStmt) (Plan, error
func (b *PlanBuilder) buildDropBindPlan(v *ast.DropBindingStmt) (Plan, error) {
p := &SQLBindPlan{
SQLBindOp: OpSQLBindDrop,
NormdOrigSQL: parser.Normalize(v.OriginNode.Text()),
NormdOrigSQL: parser.Normalize(utilparser.RestoreWithDefaultDB(v.OriginNode, b.ctx.GetSessionVars().CurrentDB)),
IsGlobal: v.GlobalScope,
Db: utilparser.GetDefaultDB(v.OriginNode, b.ctx.GetSessionVars().CurrentDB),
}
if v.HintedNode != nil {
p.BindSQL = v.HintedNode.Text()
p.BindSQL = utilparser.RestoreWithDefaultDB(v.HintedNode, b.ctx.GetSessionVars().CurrentDB)
}
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SuperPriv, "", "", "", nil)
return p, nil
}

func checkHintedSQL(sql, charset, collation, db string) error {
p := parser.New()
hintsSet, _, warns, err := hint.ParseHintsSet(p, sql, charset, collation, db)
if err != nil {
return err
}
hintsStr, err := hintsSet.Restore()
if err != nil {
return err
}
// For `create global binding for select * from t using select * from t`, we allow it though hintsStr is empty.
// For `create global binding for select * from t using select /*+ non_exist_hint() */ * from t`,
// the hint is totally invalid, we escalate warning to error.
if hintsStr == "" && len(warns) > 0 {
return warns[0]
}
return nil
}

func (b *PlanBuilder) buildCreateBindPlan(v *ast.CreateBindingStmt) (Plan, error) {
charSet, collation := b.ctx.GetSessionVars().GetCharsetInfo()

// Because we use HintedNode.Restore instead of HintedNode.Text, so we need do some check here
// For example, if HintedNode.Text is `select /*+ non_exist_hint() */ * from t` and the current DB is `test`,
// the HintedNode.Restore will be `select * from test . t`.
// In other words, illegal hints will be deleted during restore. We can't check hinted SQL after restore.
// So we need check here.
if err := checkHintedSQL(v.HintedNode.Text(), charSet, collation, b.ctx.GetSessionVars().CurrentDB); err != nil {
return nil, err
}

p := &SQLBindPlan{
SQLBindOp: OpSQLBindCreate,
NormdOrigSQL: parser.Normalize(v.OriginNode.Text()),
BindSQL: v.HintedNode.Text(),
NormdOrigSQL: parser.Normalize(utilparser.RestoreWithDefaultDB(v.OriginNode, b.ctx.GetSessionVars().CurrentDB)),
BindSQL: utilparser.RestoreWithDefaultDB(v.HintedNode, b.ctx.GetSessionVars().CurrentDB),
IsGlobal: v.GlobalScope,
BindStmt: v.HintedNode,
Db: utilparser.GetDefaultDB(v.OriginNode, b.ctx.GetSessionVars().CurrentDB),
Expand Down
11 changes: 6 additions & 5 deletions planner/core/preprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
driver "github.com/pingcap/tidb/types/parser_driver"
"github.com/pingcap/tidb/util"
"github.com/pingcap/tidb/util/domainutil"
utilparser "github.com/pingcap/tidb/util/parser"
)

// PreprocessOpt presents optional parameters to `Preprocess` method.
Expand Down Expand Up @@ -179,14 +180,14 @@ func (p *preprocessor) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
p.stmtTp = TypeCreate
EraseLastSemicolon(node.OriginNode)
EraseLastSemicolon(node.HintedNode)
p.checkBindGrammar(node.OriginNode, node.HintedNode)
p.checkBindGrammar(node.OriginNode, node.HintedNode, p.ctx.GetSessionVars().CurrentDB)
return in, true
case *ast.DropBindingStmt:
p.stmtTp = TypeDrop
EraseLastSemicolon(node.OriginNode)
if node.HintedNode != nil {
EraseLastSemicolon(node.HintedNode)
p.checkBindGrammar(node.OriginNode, node.HintedNode)
p.checkBindGrammar(node.OriginNode, node.HintedNode, p.ctx.GetSessionVars().CurrentDB)
}
return in, true
case *ast.RecoverTableStmt, *ast.FlashBackTableStmt:
Expand Down Expand Up @@ -291,7 +292,7 @@ func bindableStmtType(node ast.StmtNode) byte {
return TypeInvalid
}

func (p *preprocessor) checkBindGrammar(originNode, hintedNode ast.StmtNode) {
func (p *preprocessor) checkBindGrammar(originNode, hintedNode ast.StmtNode, defaultDB string) {
origTp := bindableStmtType(originNode)
hintedTp := bindableStmtType(hintedNode)
if origTp == TypeInvalid || hintedTp == TypeInvalid {
Expand All @@ -309,8 +310,8 @@ func (p *preprocessor) checkBindGrammar(originNode, hintedNode ast.StmtNode) {
return
}
}
originSQL := parser.Normalize(originNode.Text())
hintedSQL := parser.Normalize(hintedNode.Text())
originSQL := parser.Normalize(utilparser.RestoreWithDefaultDB(originNode, defaultDB))
hintedSQL := parser.Normalize(utilparser.RestoreWithDefaultDB(hintedNode, defaultDB))
if originSQL != hintedSQL {
p.err = errors.Errorf("hinted sql and origin sql don't match when hinted sql erase the hint info, after erase hint info, originSQL:%s, hintedSQL:%s", originSQL, hintedSQL)
}
Expand Down
54 changes: 40 additions & 14 deletions planner/optimize.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/hint"
"github.com/pingcap/tidb/util/logutil"
utilparser "github.com/pingcap/tidb/util/parser"
"go.uber.org/zap"
)

Expand Down Expand Up @@ -273,13 +274,26 @@ func optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in
return finalPlan, names, cost, err
}

func extractSelectAndNormalizeDigest(stmtNode ast.StmtNode) (ast.StmtNode, string, string) {
func extractSelectAndNormalizeDigest(stmtNode ast.StmtNode, specifiledDB string) (ast.StmtNode, string, string) {
switch x := stmtNode.(type) {
case *ast.ExplainStmt:
// This function is only used to find bind record.
// For some SQLs, such as `explain select * from t`, they will be entered here many times,
// but some of them do not want to obtain bind record.
// The difference between them is whether len(x.Text()) is empty. They cannot be distinguished by stmt.restore.
// For these cases, we need return "" as normalize SQL and hash.
if len(x.Text()) == 0 {
return x.Stmt, "", ""
}
switch x.Stmt.(type) {
case *ast.SelectStmt, *ast.DeleteStmt, *ast.UpdateStmt, *ast.InsertStmt:
plannercore.EraseLastSemicolon(x)
normalizeExplainSQL := parser.Normalize(x.Text())
var normalizeExplainSQL string
if specifiledDB != "" {
normalizeExplainSQL = parser.Normalize(utilparser.RestoreWithDefaultDB(x, specifiledDB))
} else {
normalizeExplainSQL = parser.Normalize(x.Text())
}
idx := int(0)
switch n := x.Stmt.(type) {
case *ast.SelectStmt:
Expand All @@ -300,7 +314,12 @@ func extractSelectAndNormalizeDigest(stmtNode ast.StmtNode) (ast.StmtNode, strin
return x.Stmt, normalizeSQL, hash
case *ast.SetOprStmt:
plannercore.EraseLastSemicolon(x)
normalizeExplainSQL := parser.Normalize(x.Text())
var normalizeExplainSQL string
if specifiledDB != "" {
normalizeExplainSQL = parser.Normalize(utilparser.RestoreWithDefaultDB(x, specifiledDB))
} else {
normalizeExplainSQL = parser.Normalize(x.Text())
}
idx := strings.Index(normalizeExplainSQL, "select")
parenthesesIdx := strings.Index(normalizeExplainSQL, "(")
if parenthesesIdx != -1 && parenthesesIdx < idx {
Expand All @@ -312,7 +331,20 @@ func extractSelectAndNormalizeDigest(stmtNode ast.StmtNode) (ast.StmtNode, strin
}
case *ast.SelectStmt, *ast.SetOprStmt, *ast.DeleteStmt, *ast.UpdateStmt, *ast.InsertStmt:
plannercore.EraseLastSemicolon(x)
normalizedSQL, hash := parser.NormalizeDigest(x.Text())
// This function is only used to find bind record.
// For some SQLs, such as `explain select * from t`, they will be entered here many times,
// but some of them do not want to obtain bind record.
// The difference between them is whether len(x.Text()) is empty. They cannot be distinguished by stmt.restore.
// For these cases, we need return "" as normalize SQL and hash.
if len(x.Text()) == 0 {
return x, "", ""
}
var normalizedSQL, hash string
if specifiledDB != "" {
normalizedSQL, hash = parser.NormalizeDigest(utilparser.RestoreWithDefaultDB(x, specifiledDB))
} else {
normalizedSQL, hash = parser.NormalizeDigest(x.Text())
}
return x, normalizedSQL, hash
}
return nil, "", ""
Expand All @@ -323,15 +355,12 @@ func getBindRecord(ctx sessionctx.Context, stmt ast.StmtNode) (*bindinfo.BindRec
if ctx.Value(bindinfo.SessionBindInfoKeyType) == nil {
return nil, ""
}
stmtNode, normalizedSQL, hash := extractSelectAndNormalizeDigest(stmt)
stmtNode, normalizedSQL, hash := extractSelectAndNormalizeDigest(stmt, ctx.GetSessionVars().CurrentDB)
if stmtNode == nil {
return nil, ""
}
sessionHandle := ctx.Value(bindinfo.SessionBindInfoKeyType).(*bindinfo.SessionHandle)
bindRecord := sessionHandle.GetBindRecord(normalizedSQL, ctx.GetSessionVars().CurrentDB)
if bindRecord == nil {
bindRecord = sessionHandle.GetBindRecord(normalizedSQL, "")
}
bindRecord := sessionHandle.GetBindRecord(normalizedSQL, "")
if bindRecord != nil {
if bindRecord.HasUsingBinding() {
return bindRecord, metrics.ScopeSession
Expand All @@ -342,10 +371,7 @@ func getBindRecord(ctx sessionctx.Context, stmt ast.StmtNode) (*bindinfo.BindRec
if globalHandle == nil {
return nil, ""
}
bindRecord = globalHandle.GetBindRecord(hash, normalizedSQL, ctx.GetSessionVars().CurrentDB)
if bindRecord == nil {
bindRecord = globalHandle.GetBindRecord(hash, normalizedSQL, "")
}
bindRecord = globalHandle.GetBindRecord(hash, normalizedSQL, "")
return bindRecord, metrics.ScopeGlobal
}

Expand All @@ -364,7 +390,7 @@ func handleInvalidBindRecord(ctx context.Context, sctx sessionctx.Context, level
}

func handleEvolveTasks(ctx context.Context, sctx sessionctx.Context, br *bindinfo.BindRecord, stmtNode ast.StmtNode, planHint string) {
bindSQL := bindinfo.GenerateBindSQL(ctx, stmtNode, planHint, false)
bindSQL := bindinfo.GenerateBindSQL(ctx, stmtNode, planHint, false, br.Db)
if bindSQL == "" {
return
}
Expand Down
Loading

0 comments on commit 51794e9

Please sign in to comment.