Skip to content

Commit fbd8dd1

Browse files
dbjoaiamzhoug37
authored andcommitted
plan: support ? in Order By / Group By / Limit Offset clauses (pingcap#8206)
1 parent 1bc7b9c commit fbd8dd1

10 files changed

+234
-28
lines changed

executor/prepared.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ import (
2626
plannercore "github.com/pingcap/tidb/planner/core"
2727
"github.com/pingcap/tidb/sessionctx"
2828
"github.com/pingcap/tidb/sessionctx/variable"
29-
"github.com/pingcap/tidb/types"
3029
"github.com/pingcap/tidb/types/parser_driver"
3130
"github.com/pingcap/tidb/util/chunk"
3231
"github.com/pingcap/tidb/util/sqlexec"
@@ -161,7 +160,7 @@ func (e *PrepareExec) Next(ctx context.Context, chk *chunk.Chunk) error {
161160

162161
// We try to build the real statement of preparedStmt.
163162
for i := range prepared.Params {
164-
prepared.Params[i].(*driver.ParamMarkerExpr).Datum = types.NewIntDatum(0)
163+
prepared.Params[i].(*driver.ParamMarkerExpr).Datum.SetNull()
165164
}
166165
var p plannercore.Plan
167166
p, err = plannercore.BuildLogicalPlan(e.ctx, stmt, e.is)

executor/prepared_test.go

+57
Original file line numberDiff line numberDiff line change
@@ -688,3 +688,60 @@ func (s *testSuite) TestPrepareDealloc(c *C) {
688688
tk.MustExec("deallocate prepare stmt4")
689689
c.Assert(tk.Se.PreparedPlanCache().Size(), Equals, 0)
690690
}
691+
692+
func (s *testSuite) TestPreparedIssue8153(c *C) {
693+
orgEnable := plannercore.PreparedPlanCacheEnabled()
694+
orgCapacity := plannercore.PreparedPlanCacheCapacity
695+
orgMemGuardRatio := plannercore.PreparedPlanCacheMemoryGuardRatio
696+
orgMaxMemory := plannercore.PreparedPlanCacheMaxMemory
697+
defer func() {
698+
plannercore.SetPreparedPlanCache(orgEnable)
699+
plannercore.PreparedPlanCacheCapacity = orgCapacity
700+
plannercore.PreparedPlanCacheMemoryGuardRatio = orgMemGuardRatio
701+
plannercore.PreparedPlanCacheMaxMemory = orgMaxMemory
702+
}()
703+
flags := []bool{false, true}
704+
for _, flag := range flags {
705+
var err error
706+
plannercore.SetPreparedPlanCache(flag)
707+
plannercore.PreparedPlanCacheCapacity = 100
708+
plannercore.PreparedPlanCacheMemoryGuardRatio = 0.1
709+
plannercore.PreparedPlanCacheMaxMemory, err = memory.MemTotal()
710+
tk := testkit.NewTestKit(c, s.store)
711+
tk.MustExec("use test")
712+
tk.MustExec("drop table if exists t")
713+
tk.MustExec("create table t (a int, b int)")
714+
tk.MustExec("insert into t (a, b) values (1,3), (2,2), (3,1)")
715+
716+
tk.MustExec(`prepare stmt from 'select * from t order by ? asc'`)
717+
r := tk.MustQuery(`execute stmt using @param;`)
718+
r.Check(testkit.Rows("1 3", "2 2", "3 1"))
719+
720+
tk.MustExec(`set @param = 1`)
721+
r = tk.MustQuery(`execute stmt using @param;`)
722+
r.Check(testkit.Rows("1 3", "2 2", "3 1"))
723+
724+
tk.MustExec(`set @param = 2`)
725+
r = tk.MustQuery(`execute stmt using @param;`)
726+
r.Check(testkit.Rows("3 1", "2 2", "1 3"))
727+
728+
tk.MustExec(`set @param = 3`)
729+
_, err = tk.Exec(`execute stmt using @param;`)
730+
c.Assert(err.Error(), Equals, "[planner:1054]Unknown column '?' in 'order clause'")
731+
732+
tk.MustExec(`set @param = '##'`)
733+
r = tk.MustQuery(`execute stmt using @param;`)
734+
r.Check(testkit.Rows("1 3", "2 2", "3 1"))
735+
736+
tk.MustExec("insert into t (a, b) values (1,1), (1,2), (2,1), (2,3), (3,2), (3,3)")
737+
tk.MustExec(`prepare stmt from 'select ?, sum(a) from t group by ?'`)
738+
739+
tk.MustExec(`set @a=1,@b=1`)
740+
r = tk.MustQuery(`execute stmt using @a,@b;`)
741+
r.Check(testkit.Rows("1 18"))
742+
743+
tk.MustExec(`set @a=1,@b=2`)
744+
_, err = tk.Exec(`execute stmt using @a,@b;`)
745+
c.Assert(err.Error(), Equals, "[planner:1056]Can't group on 'sum(a)'")
746+
}
747+
}

expression/simple_rewriter.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ func (sr *simpleRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok boo
151151
}
152152
case *driver.ParamMarkerExpr:
153153
var value Expression
154-
value, sr.err = GetParamExpression(sr.ctx, v, sr.useCache())
154+
value, sr.err = GetParamExpression(sr.ctx, v)
155155
if sr.err != nil {
156156
return retNode, false
157157
}

expression/util.go

+50-1
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,8 @@ func DatumToConstant(d types.Datum, tp byte) *Constant {
511511
}
512512

513513
// GetParamExpression generate a getparam function expression.
514-
func GetParamExpression(ctx sessionctx.Context, v *driver.ParamMarkerExpr, useCache bool) (Expression, error) {
514+
func GetParamExpression(ctx sessionctx.Context, v *driver.ParamMarkerExpr) (Expression, error) {
515+
useCache := ctx.GetSessionVars().StmtCtx.UseCache
515516
tp := types.NewFieldType(mysql.TypeUnspecified)
516517
types.DefaultParamTypeForValue(v.GetValue(), tp)
517518
value := &Constant{Value: v.Datum, RetType: tp}
@@ -526,3 +527,51 @@ func GetParamExpression(ctx sessionctx.Context, v *driver.ParamMarkerExpr, useCa
526527
}
527528
return value, nil
528529
}
530+
531+
// ConstructPositionExpr constructs PositionExpr with the given ParamMarkerExpr.
532+
func ConstructPositionExpr(p *driver.ParamMarkerExpr) *ast.PositionExpr {
533+
return &ast.PositionExpr{P: p}
534+
}
535+
536+
// PosFromPositionExpr generates a position value from PositionExpr.
537+
func PosFromPositionExpr(ctx sessionctx.Context, v *ast.PositionExpr) (int, bool, error) {
538+
if v.P == nil {
539+
return v.N, false, nil
540+
}
541+
value, err := GetParamExpression(ctx, v.P.(*driver.ParamMarkerExpr))
542+
if err != nil {
543+
return 0, true, err
544+
}
545+
pos, isNull, err := GetIntFromConstant(ctx, value)
546+
if err != nil || isNull {
547+
return 0, true, errors.Trace(err)
548+
}
549+
return pos, false, nil
550+
}
551+
552+
// GetStringFromConstant gets a string value from the Constant expression.
553+
func GetStringFromConstant(ctx sessionctx.Context, value Expression) (string, bool, error) {
554+
con, ok := value.(*Constant)
555+
if !ok {
556+
err := errors.Errorf("Not a Constant expression %+v", value)
557+
return "", true, errors.Trace(err)
558+
}
559+
str, isNull, err := con.EvalString(ctx, chunk.Row{})
560+
if err != nil || isNull {
561+
return "", true, errors.Trace(err)
562+
}
563+
return str, false, nil
564+
}
565+
566+
// GetIntFromConstant gets an interger value from the Constant expression.
567+
func GetIntFromConstant(ctx sessionctx.Context, value Expression) (int, bool, error) {
568+
str, isNull, err := GetStringFromConstant(ctx, value)
569+
if err != nil || isNull {
570+
return 0, true, errors.Trace(err)
571+
}
572+
intNum, err := strconv.Atoi(str)
573+
if err != nil {
574+
return 0, true, nil
575+
}
576+
return intNum, false, nil
577+
}

planner/core/cacheable_checker.go

+14
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,20 @@ func (checker *cacheableChecker) Enter(in ast.Node) (out ast.Node, skipChildren
5555
checker.cacheable = false
5656
return in, true
5757
}
58+
case *ast.OrderByClause:
59+
for _, item := range node.Items {
60+
if _, isParamMarker := item.Expr.(*driver.ParamMarkerExpr); isParamMarker {
61+
checker.cacheable = false
62+
return in, true
63+
}
64+
}
65+
case *ast.GroupByClause:
66+
for _, item := range node.Items {
67+
if _, isParamMarker := item.Expr.(*driver.ParamMarkerExpr); isParamMarker {
68+
checker.cacheable = false
69+
return in, true
70+
}
71+
}
5872
case *ast.Limit:
5973
if node.Count != nil {
6074
if _, isParamMarker := node.Count.(*driver.ParamMarkerExpr); isParamMarker {

planner/core/cacheable_checker_test.go

+14
Original file line numberDiff line numberDiff line change
@@ -177,4 +177,18 @@ func (s *testCacheableSuite) TestCacheable(c *C) {
177177
Limit: limitStmt,
178178
}
179179
c.Assert(Cacheable(stmt), IsTrue)
180+
181+
paramExpr := &driver.ParamMarkerExpr{}
182+
orderByClause := &ast.OrderByClause{Items: []*ast.ByItem{{Expr: paramExpr}}}
183+
stmt = &ast.SelectStmt{
184+
OrderBy: orderByClause,
185+
}
186+
c.Assert(Cacheable(stmt), IsFalse)
187+
188+
valExpr := &driver.ValueExpr{}
189+
orderByClause = &ast.OrderByClause{Items: []*ast.ByItem{{Expr: valExpr}}}
190+
stmt = &ast.SelectStmt{
191+
OrderBy: orderByClause,
192+
}
193+
c.Assert(Cacheable(stmt), IsTrue)
180194
}

planner/core/errors.go

+2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ const (
2828

2929
codeWrongUsage = mysql.ErrWrongUsage
3030
codeAmbiguous = mysql.ErrNonUniq
31+
codeUnknown = mysql.ErrUnknown
3132
codeUnknownColumn = mysql.ErrBadField
3233
codeUnknownTable = mysql.ErrUnknownTable
3334
codeWrongArguments = mysql.ErrWrongArguments
@@ -64,6 +65,7 @@ var (
6465

6566
ErrWrongUsage = terror.ClassOptimizer.New(codeWrongUsage, mysql.MySQLErrName[mysql.ErrWrongUsage])
6667
ErrAmbiguous = terror.ClassOptimizer.New(codeAmbiguous, mysql.MySQLErrName[mysql.ErrNonUniq])
68+
ErrUnknown = terror.ClassOptimizer.New(codeUnknown, mysql.MySQLErrName[mysql.ErrUnknown])
6769
ErrUnknownColumn = terror.ClassOptimizer.New(codeUnknownColumn, mysql.MySQLErrName[mysql.ErrBadField])
6870
ErrUnknownTable = terror.ClassOptimizer.New(codeUnknownTable, mysql.MySQLErrName[mysql.ErrUnknownTable])
6971
ErrWrongArguments = terror.ClassOptimizer.New(codeWrongArguments, mysql.MySQLErrName[mysql.ErrWrongArguments])

planner/core/expression_rewriter.go

+20-4
Original file line numberDiff line numberDiff line change
@@ -756,7 +756,7 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok
756756
er.ctxStack = append(er.ctxStack, value)
757757
case *driver.ParamMarkerExpr:
758758
var value expression.Expression
759-
value, er.err = expression.GetParamExpression(er.ctx, v, er.useCache())
759+
value, er.err = expression.GetParamExpression(er.ctx, v)
760760
if er.err != nil {
761761
return retNode, false
762762
}
@@ -941,10 +941,26 @@ func (er *expressionRewriter) isNullToExpression(v *ast.IsNullExpr) {
941941
}
942942

943943
func (er *expressionRewriter) positionToScalarFunc(v *ast.PositionExpr) {
944-
if v.N > 0 && v.N <= er.schema.Len() {
945-
er.ctxStack = append(er.ctxStack, er.schema.Columns[v.N-1])
944+
pos := v.N
945+
str := strconv.Itoa(pos)
946+
if v.P != nil {
947+
stkLen := len(er.ctxStack)
948+
val := er.ctxStack[stkLen-1]
949+
intNum, isNull, err := expression.GetIntFromConstant(er.ctx, val)
950+
str = "?"
951+
if err == nil {
952+
if isNull {
953+
return
954+
}
955+
pos = intNum
956+
er.ctxStack = er.ctxStack[:stkLen-1]
957+
}
958+
er.err = err
959+
}
960+
if er.err == nil && pos > 0 && pos <= er.schema.Len() {
961+
er.ctxStack = append(er.ctxStack, er.schema.Columns[pos-1])
946962
} else {
947-
er.err = ErrUnknownColumn.GenWithStackByArgs(strconv.Itoa(v.N), clauseMsg[er.b.curClause])
963+
er.err = ErrUnknownColumn.GenWithStackByArgs(str, clauseMsg[er.b.curClause])
948964
}
949965
}
950966

0 commit comments

Comments
 (0)