Skip to content

Commit c0c1360

Browse files
dbjoazz-jason
authored andcommitted
plan: support ? in Order By / Group By / Limit Offset clauses (pingcap#8206)
1 parent 89b35b3 commit c0c1360

12 files changed

+262
-36
lines changed

executor/prepared.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import (
2525
plannercore "github.com/pingcap/tidb/planner/core"
2626
"github.com/pingcap/tidb/sessionctx"
2727
"github.com/pingcap/tidb/sessionctx/variable"
28-
"github.com/pingcap/tidb/types"
2928
"github.com/pingcap/tidb/types/parser_driver"
3029
"github.com/pingcap/tidb/util/chunk"
3130
"github.com/pingcap/tidb/util/sqlexec"
@@ -165,7 +164,7 @@ func (e *PrepareExec) Next(ctx context.Context, chk *chunk.Chunk) error {
165164

166165
// We try to build the real statement of preparedStmt.
167166
for i := range prepared.Params {
168-
prepared.Params[i].(*driver.ParamMarkerExpr).Datum = types.NewIntDatum(0)
167+
prepared.Params[i].(*driver.ParamMarkerExpr).Datum.SetNull()
169168
}
170169
var p plannercore.Plan
171170
p, err = plannercore.BuildLogicalPlan(e.ctx, stmt, e.is)

executor/prepared_test.go

+51
Original file line numberDiff line numberDiff line change
@@ -366,3 +366,54 @@ func generateBatchSQL(paramCount int) (sql string, paramSlice []interface{}) {
366366
}
367367
return "insert into t values " + strings.Join(placeholders, ","), params
368368
}
369+
370+
func (s *testSuite) TestPreparedIssue8153(c *C) {
371+
orgEnable := plannercore.PreparedPlanCacheEnabled()
372+
orgCapacity := plannercore.PreparedPlanCacheCapacity
373+
defer func() {
374+
plannercore.SetPreparedPlanCache(orgEnable)
375+
plannercore.PreparedPlanCacheCapacity = orgCapacity
376+
}()
377+
flags := []bool{false, true}
378+
for _, flag := range flags {
379+
var err error
380+
plannercore.SetPreparedPlanCache(flag)
381+
plannercore.PreparedPlanCacheCapacity = 100
382+
tk := testkit.NewTestKit(c, s.store)
383+
tk.MustExec("use test")
384+
tk.MustExec("drop table if exists t")
385+
tk.MustExec("create table t (a int, b int)")
386+
tk.MustExec("insert into t (a, b) values (1,3), (2,2), (3,1)")
387+
388+
tk.MustExec(`prepare stmt from 'select * from t order by ? asc'`)
389+
r := tk.MustQuery(`execute stmt using @param;`)
390+
r.Check(testkit.Rows("1 3", "2 2", "3 1"))
391+
392+
tk.MustExec(`set @param = 1`)
393+
r = tk.MustQuery(`execute stmt using @param;`)
394+
r.Check(testkit.Rows("1 3", "2 2", "3 1"))
395+
396+
tk.MustExec(`set @param = 2`)
397+
r = tk.MustQuery(`execute stmt using @param;`)
398+
r.Check(testkit.Rows("3 1", "2 2", "1 3"))
399+
400+
tk.MustExec(`set @param = 3`)
401+
_, err = tk.Exec(`execute stmt using @param;`)
402+
c.Assert(err.Error(), Equals, "[planner:1054]Unknown column '?' in 'order clause'")
403+
404+
tk.MustExec(`set @param = '##'`)
405+
r = tk.MustQuery(`execute stmt using @param;`)
406+
r.Check(testkit.Rows("1 3", "2 2", "3 1"))
407+
408+
tk.MustExec("insert into t (a, b) values (1,1), (1,2), (2,1), (2,3), (3,2), (3,3)")
409+
tk.MustExec(`prepare stmt from 'select ?, sum(a) from t group by ?'`)
410+
411+
tk.MustExec(`set @a=1,@b=1`)
412+
r = tk.MustQuery(`execute stmt using @a,@b;`)
413+
r.Check(testkit.Rows("1 18"))
414+
415+
tk.MustExec(`set @a=1,@b=2`)
416+
_, err = tk.Exec(`execute stmt using @a,@b;`)
417+
c.Assert(err.Error(), Equals, "[planner:1056]Can't group on 'sum(a)'")
418+
}
419+
}

expression/simple_rewriter.go

+5-3
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,11 @@ func (sr *simpleRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok boo
156156
sr.inToExpression(len(v.List), v.Not, &v.Type)
157157
}
158158
case *driver.ParamMarkerExpr:
159-
tp := types.NewFieldType(mysql.TypeUnspecified)
160-
types.DefaultParamTypeForValue(v.GetValue(), tp)
161-
value := &Constant{Value: v.ValueExpr.Datum, RetType: tp}
159+
var value Expression
160+
value, sr.err = GetParamExpression(sr.ctx, v)
161+
if sr.err != nil {
162+
return retNode, false
163+
}
162164
sr.push(value)
163165
case *ast.RowExpr:
164166
sr.rowToScalarFunc(v)

expression/util.go

+72
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"github.com/pingcap/parser/terror"
2727
"github.com/pingcap/tidb/sessionctx"
2828
"github.com/pingcap/tidb/types"
29+
driver "github.com/pingcap/tidb/types/parser_driver"
2930
"github.com/pingcap/tidb/util/chunk"
3031
"github.com/pingcap/tidb/util/hack"
3132
"golang.org/x/tools/container/intsets"
@@ -555,3 +556,74 @@ func DisableParseJSONFlag4Expr(expr Expression) {
555556
}
556557
expr.GetType().Flag &= ^mysql.ParseToJSONFlag
557558
}
559+
560+
// DatumToConstant generates a Constant expression from a Datum.
561+
func DatumToConstant(d types.Datum, tp byte) *Constant {
562+
return &Constant{Value: d, RetType: types.NewFieldType(tp)}
563+
}
564+
565+
// GetParamExpression generate a getparam function expression.
566+
func GetParamExpression(ctx sessionctx.Context, v *driver.ParamMarkerExpr) (Expression, error) {
567+
useCache := ctx.GetSessionVars().StmtCtx.UseCache
568+
tp := types.NewFieldType(mysql.TypeUnspecified)
569+
types.DefaultParamTypeForValue(v.GetValue(), tp)
570+
value := &Constant{Value: v.Datum, RetType: tp}
571+
if useCache {
572+
f, err := NewFunctionBase(ctx, ast.GetParam, &v.Type,
573+
DatumToConstant(types.NewIntDatum(int64(v.Order)), mysql.TypeLonglong))
574+
if err != nil {
575+
return nil, errors.Trace(err)
576+
}
577+
f.GetType().Tp = v.Type.Tp
578+
value.DeferredExpr = f
579+
}
580+
return value, nil
581+
}
582+
583+
// ConstructPositionExpr constructs PositionExpr with the given ParamMarkerExpr.
584+
func ConstructPositionExpr(p *driver.ParamMarkerExpr) *ast.PositionExpr {
585+
return &ast.PositionExpr{P: p}
586+
}
587+
588+
// PosFromPositionExpr generates a position value from PositionExpr.
589+
func PosFromPositionExpr(ctx sessionctx.Context, v *ast.PositionExpr) (int, bool, error) {
590+
if v.P == nil {
591+
return v.N, false, nil
592+
}
593+
value, err := GetParamExpression(ctx, v.P.(*driver.ParamMarkerExpr))
594+
if err != nil {
595+
return 0, true, err
596+
}
597+
pos, isNull, err := GetIntFromConstant(ctx, value)
598+
if err != nil || isNull {
599+
return 0, true, errors.Trace(err)
600+
}
601+
return pos, false, nil
602+
}
603+
604+
// GetStringFromConstant gets a string value from the Constant expression.
605+
func GetStringFromConstant(ctx sessionctx.Context, value Expression) (string, bool, error) {
606+
con, ok := value.(*Constant)
607+
if !ok {
608+
err := errors.Errorf("Not a Constant expression %+v", value)
609+
return "", true, errors.Trace(err)
610+
}
611+
str, isNull, err := con.EvalString(ctx, chunk.Row{})
612+
if err != nil || isNull {
613+
return "", true, errors.Trace(err)
614+
}
615+
return str, false, nil
616+
}
617+
618+
// GetIntFromConstant gets an interger value from the Constant expression.
619+
func GetIntFromConstant(ctx sessionctx.Context, value Expression) (int, bool, error) {
620+
str, isNull, err := GetStringFromConstant(ctx, value)
621+
if err != nil || isNull {
622+
return 0, true, errors.Trace(err)
623+
}
624+
intNum, err := strconv.Atoi(str)
625+
if err != nil {
626+
return 0, true, nil
627+
}
628+
return intNum, false, nil
629+
}

go.mod

+2
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,5 @@ require (
8484
gopkg.in/natefinch/lumberjack.v2 v2.0.0
8585
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
8686
)
87+
88+
replace github.com/pingcap/parser => github.com/zz-jason/parser v0.0.0-20191003033834-cce7a9500e2e

go.sum

+2-2
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,6 @@ github.com/pingcap/kvproto v0.0.0-20190826051950-fc8799546726 h1:AzGIEmaYVYMtmki
109109
github.com/pingcap/kvproto v0.0.0-20190826051950-fc8799546726/go.mod h1:0gwbe1F2iBIjuQ9AH0DbQhL+Dpr5GofU8fgYyXk+ykk=
110110
github.com/pingcap/log v0.0.0-20190307075452-bd41d9273596 h1:t2OQTpPJnrPDGlvA+3FwJptMTt6MEPdzK1Wt99oaefQ=
111111
github.com/pingcap/log v0.0.0-20190307075452-bd41d9273596/go.mod h1:WpHUKhNZ18v116SvGrmjkA9CBhYmuUTKL+p8JC9ANEw=
112-
github.com/pingcap/parser v0.0.0-20190910040957-e998b3c52469 h1:JS/p4qMInVXTyV0kjFz+n0DBGn/n1T0cZDjEYHdTQow=
113-
github.com/pingcap/parser v0.0.0-20190910040957-e998b3c52469/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA=
114112
github.com/pingcap/pd v2.1.12+incompatible h1:6N3LBxx2aSZqT+IWEG730EDNDttP7dXO8J6yvBh+HXw=
115113
github.com/pingcap/pd v2.1.12+incompatible/go.mod h1:nD3+EoYes4+aNNODO99ES59V83MZSI+dFbhyr667a0E=
116114
github.com/pingcap/tidb-tools v2.1.3-0.20190116051332-34c808eef588+incompatible h1:e9Gi/LP9181HT3gBfSOeSBA+5JfemuE4aEAhqNgoE4k=
@@ -151,6 +149,8 @@ github.com/unrolled/render v0.0.0-20171102162132-65450fb6b2d3/go.mod h1:tu82oB5W
151149
github.com/xiang90/probing v0.0.0-20160813154853-07dd2e8dfe18 h1:MPPkRncZLN9Kh4MEFmbnK4h3BD7AUmskWv2+EeZJCCs=
152150
github.com/xiang90/probing v0.0.0-20160813154853-07dd2e8dfe18/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU=
153151
github.com/yookoala/realpath v1.0.0/go.mod h1:gJJMA9wuX7AcqLy1+ffPatSCySA1FQ2S8Ya9AIoYBpE=
152+
github.com/zz-jason/parser v0.0.0-20191003033834-cce7a9500e2e h1:oxazCGeHJ+CdDGPGVeIpIBzJ4dw0DNqDI5wdXPVZb8Q=
153+
github.com/zz-jason/parser v0.0.0-20191003033834-cce7a9500e2e/go.mod h1:mnf7H9ngMZzobilLo3+bu86/+DSlGQBnmse9S5K8PKQ=
154154
go.etcd.io/bbolt v1.3.3 h1:MUGmc65QhB3pIlaQ5bB4LwqSj6GIonVJXpZiaKNyaKk=
155155
go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
156156
go.uber.org/atomic v1.3.2 h1:2Oa65PReHzfn29GpvgsYwloV9AVFHPDk8tYxt2c2tr4=

planner/core/cacheable_checker.go

+14
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,20 @@ func (checker *cacheableChecker) Enter(in ast.Node) (out ast.Node, skipChildren
5151
checker.cacheable = false
5252
return in, true
5353
}
54+
case *ast.OrderByClause:
55+
for _, item := range node.Items {
56+
if _, isParamMarker := item.Expr.(*driver.ParamMarkerExpr); isParamMarker {
57+
checker.cacheable = false
58+
return in, true
59+
}
60+
}
61+
case *ast.GroupByClause:
62+
for _, item := range node.Items {
63+
if _, isParamMarker := item.Expr.(*driver.ParamMarkerExpr); isParamMarker {
64+
checker.cacheable = false
65+
return in, true
66+
}
67+
}
5468
case *ast.Limit:
5569
if node.Count != nil {
5670
if _, isParamMarker := node.Count.(*driver.ParamMarkerExpr); isParamMarker {

planner/core/cacheable_checker_test.go

+14
Original file line numberDiff line numberDiff line change
@@ -87,4 +87,18 @@ func (s *testCacheableSuite) TestCacheable(c *C) {
8787
Limit: limitStmt,
8888
}
8989
c.Assert(Cacheable(stmt), IsTrue)
90+
91+
paramExpr := &driver.ParamMarkerExpr{}
92+
orderByClause := &ast.OrderByClause{Items: []*ast.ByItem{{Expr: paramExpr}}}
93+
stmt = &ast.SelectStmt{
94+
OrderBy: orderByClause,
95+
}
96+
c.Assert(Cacheable(stmt), IsFalse)
97+
98+
valExpr := &driver.ValueExpr{}
99+
orderByClause = &ast.OrderByClause{Items: []*ast.ByItem{{Expr: valExpr}}}
100+
stmt = &ast.SelectStmt{
101+
OrderBy: orderByClause,
102+
}
103+
c.Assert(Cacheable(stmt), IsTrue)
90104
}

planner/core/errors.go

+2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ const (
2828

2929
codeWrongUsage = mysql.ErrWrongUsage
3030
codeAmbiguous = mysql.ErrNonUniq
31+
codeUnknown = mysql.ErrUnknown
3132
codeUnknownColumn = mysql.ErrBadField
3233
codeUnknownTable = mysql.ErrUnknownTable
3334
codeWrongArguments = mysql.ErrWrongArguments
@@ -65,6 +66,7 @@ var (
6566

6667
ErrWrongUsage = terror.ClassOptimizer.New(codeWrongUsage, mysql.MySQLErrName[mysql.ErrWrongUsage])
6768
ErrAmbiguous = terror.ClassOptimizer.New(codeAmbiguous, mysql.MySQLErrName[mysql.ErrNonUniq])
69+
ErrUnknown = terror.ClassOptimizer.New(codeUnknown, mysql.MySQLErrName[mysql.ErrUnknown])
6870
ErrUnknownColumn = terror.ClassOptimizer.New(codeUnknownColumn, mysql.MySQLErrName[mysql.ErrBadField])
6971
ErrUnknownTable = terror.ClassOptimizer.New(codeUnknownTable, mysql.MySQLErrName[mysql.ErrUnknownTable])
7072
ErrWrongArguments = terror.ClassOptimizer.New(codeWrongArguments, mysql.MySQLErrName[mysql.ErrWrongArguments])

planner/core/expression_rewriter.go

+23-8
Original file line numberDiff line numberDiff line change
@@ -820,11 +820,10 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok
820820
value := &expression.Constant{Value: v.Datum, RetType: &v.Type}
821821
er.ctxStack = append(er.ctxStack, value)
822822
case *driver.ParamMarkerExpr:
823-
tp := types.NewFieldType(mysql.TypeUnspecified)
824-
types.DefaultParamTypeForValue(v.GetValue(), tp)
825-
value := &expression.Constant{Value: v.Datum, RetType: tp}
826-
if er.useCache() {
827-
value.DeferredExpr = er.getParamExpression(v)
823+
var value expression.Expression
824+
value, er.err = expression.GetParamExpression(er.ctx, v)
825+
if er.err != nil {
826+
return retNode, false
828827
}
829828
er.ctxStack = append(er.ctxStack, value)
830829
case *ast.VariableExpr:
@@ -1044,10 +1043,26 @@ func (er *expressionRewriter) isNullToExpression(v *ast.IsNullExpr) {
10441043
}
10451044

10461045
func (er *expressionRewriter) positionToScalarFunc(v *ast.PositionExpr) {
1047-
if v.N > 0 && v.N <= er.schema.Len() {
1048-
er.ctxStack = append(er.ctxStack, er.schema.Columns[v.N-1])
1046+
pos := v.N
1047+
str := strconv.Itoa(pos)
1048+
if v.P != nil {
1049+
stkLen := len(er.ctxStack)
1050+
val := er.ctxStack[stkLen-1]
1051+
intNum, isNull, err := expression.GetIntFromConstant(er.ctx, val)
1052+
str = "?"
1053+
if err == nil {
1054+
if isNull {
1055+
return
1056+
}
1057+
pos = intNum
1058+
er.ctxStack = er.ctxStack[:stkLen-1]
1059+
}
1060+
er.err = err
1061+
}
1062+
if er.err == nil && pos > 0 && pos <= er.schema.Len() {
1063+
er.ctxStack = append(er.ctxStack, er.schema.Columns[pos-1])
10491064
} else {
1050-
er.err = ErrUnknownColumn.GenWithStackByArgs(strconv.Itoa(v.N), clauseMsg[er.b.curClause])
1065+
er.err = ErrUnknownColumn.GenWithStackByArgs(str, clauseMsg[er.b.curClause])
10511066
}
10521067
}
10531068

0 commit comments

Comments
 (0)