Skip to content

Commit

Permalink
expression: add max_allowed_packet check in concat/concat_ws (#11137)
Browse files Browse the repository at this point in the history
  • Loading branch information
amyangfei authored and SunRunAway committed Jul 16, 2019
1 parent eae30eb commit 593fb7d
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 12 deletions.
55 changes: 43 additions & 12 deletions expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,17 +272,26 @@ func (c *concatFunctionClass) getFunction(ctx sessionctx.Context, args []Express
if bf.tp.Flen >= mysql.MaxBlobWidth {
bf.tp.Flen = mysql.MaxBlobWidth
}
sig := &builtinConcatSig{bf}

valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
}

sig := &builtinConcatSig{bf, maxAllowedPacket}
return sig, nil
}

type builtinConcatSig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinConcatSig) Clone() builtinFunc {
newSig := &builtinConcatSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand All @@ -295,6 +304,10 @@ func (b *builtinConcatSig) evalString(row chunk.Row) (d string, isNull bool, err
if isNull || err != nil {
return d, isNull, err
}
if uint64(len(s)+len(d)) > b.maxAllowedPacket {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenWithStackByArgs("concat", b.maxAllowedPacket))
return "", true, nil
}
s = append(s, []byte(d)...)
}
return string(s), false, nil
Expand Down Expand Up @@ -337,17 +350,25 @@ func (c *concatWSFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
bf.tp.Flen = mysql.MaxBlobWidth
}

sig := &builtinConcatWSSig{bf}
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
}

sig := &builtinConcatWSSig{bf, maxAllowedPacket}
return sig, nil
}

type builtinConcatWSSig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinConcatWSSig) Clone() builtinFunc {
newSig := &builtinConcatWSSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand All @@ -357,25 +378,35 @@ func (b *builtinConcatWSSig) evalString(row chunk.Row) (string, bool, error) {
args := b.getArgs()
strs := make([]string, 0, len(args))
var sep string
for i, arg := range args {
val, isNull, err := arg.EvalString(b.ctx, row)
var targetLength int

N := len(args)
if N > 0 {
val, isNull, err := args[0].EvalString(b.ctx, row)
if err != nil || isNull {
// If the separator is NULL, the result is NULL.
return val, isNull, err
}
sep = val
}
for i := 1; i < N; i++ {
val, isNull, err := args[i].EvalString(b.ctx, row)
if err != nil {
return val, isNull, err
}

if isNull {
// If the separator is NULL, the result is NULL.
if i == 0 {
return val, isNull, nil
}
// CONCAT_WS() does not skip empty strings. However,
// it does skip any NULL values after the separator argument.
continue
}

if i == 0 {
sep = val
continue
targetLength += len(val)
if i > 1 {
targetLength += len(sep)
}
if uint64(targetLength) > b.maxAllowedPacket {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenWithStackByArgs("concat_ws", b.maxAllowedPacket))
return "", true, nil
}
strs = append(strs, val)
}
Expand Down
91 changes: 91 additions & 0 deletions expression/builtin_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,50 @@ func (s *testEvaluatorSuite) TestConcat(c *C) {
}
}

func (s *testEvaluatorSuite) TestConcatSig(c *C) {
colTypes := []*types.FieldType{
{Tp: mysql.TypeVarchar},
{Tp: mysql.TypeVarchar},
}
resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 1000}
args := []Expression{
&Column{Index: 0, RetType: colTypes[0]},
&Column{Index: 1, RetType: colTypes[1]},
}
base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType}
concat := &builtinConcatSig{base, 5}

cases := []struct {
args []interface{}
warnings int
res string
}{
{[]interface{}{"a", "b"}, 0, "ab"},
{[]interface{}{"aaa", "bbb"}, 1, ""},
{[]interface{}{"中", "a"}, 0, "中a"},
{[]interface{}{"中文", "a"}, 2, ""},
}

for _, t := range cases {
input := chunk.NewChunkWithCapacity(colTypes, 10)
input.AppendString(0, t.args[0].(string))
input.AppendString(1, t.args[1].(string))

res, isNull, err := concat.evalString(input.GetRow(0))
c.Assert(res, Equals, t.res)
c.Assert(err, IsNil)
if t.warnings == 0 {
c.Assert(isNull, IsFalse)
} else {
c.Assert(isNull, IsTrue)
warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings()
c.Assert(warnings, HasLen, t.warnings)
lastWarn := warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue)
}
}
}

func (s *testEvaluatorSuite) TestConcatWS(c *C) {
defer testleak.AfterTest(c)()
cases := []struct {
Expand Down Expand Up @@ -246,6 +290,53 @@ func (s *testEvaluatorSuite) TestConcatWS(c *C) {
c.Assert(err, IsNil)
}

func (s *testEvaluatorSuite) TestConcatWSSig(c *C) {
colTypes := []*types.FieldType{
{Tp: mysql.TypeVarchar},
{Tp: mysql.TypeVarchar},
{Tp: mysql.TypeVarchar},
}
resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 1000}
args := []Expression{
&Column{Index: 0, RetType: colTypes[0]},
&Column{Index: 1, RetType: colTypes[1]},
&Column{Index: 2, RetType: colTypes[2]},
}
base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType}
concat := &builtinConcatWSSig{base, 6}

cases := []struct {
args []interface{}
warnings int
res string
}{
{[]interface{}{",", "a", "b"}, 0, "a,b"},
{[]interface{}{",", "aaa", "bbb"}, 1, ""},
{[]interface{}{",", "中", "a"}, 0, "中,a"},
{[]interface{}{",", "中文", "a"}, 2, ""},
}

for _, t := range cases {
input := chunk.NewChunkWithCapacity(colTypes, 10)
input.AppendString(0, t.args[0].(string))
input.AppendString(1, t.args[1].(string))
input.AppendString(2, t.args[2].(string))

res, isNull, err := concat.evalString(input.GetRow(0))
c.Assert(res, Equals, t.res)
c.Assert(err, IsNil)
if t.warnings == 0 {
c.Assert(isNull, IsFalse)
} else {
c.Assert(isNull, IsTrue)
warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings()
c.Assert(warnings, HasLen, t.warnings)
lastWarn := warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue)
}
}
}

func (s *testEvaluatorSuite) TestLeft(c *C) {
defer testleak.AfterTest(c)()
stmtCtx := s.ctx.GetSessionVars().StmtCtx
Expand Down
3 changes: 3 additions & 0 deletions util/mock/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,9 @@ func NewContext() *Context {
sctx.sessionVars.MaxChunkSize = 32
sctx.sessionVars.StmtCtx.TimeZone = time.UTC
sctx.sessionVars.GlobalVarsAccessor = variable.NewMockGlobalAccessor()
if err := sctx.GetSessionVars().SetSystemVar(variable.MaxAllowedPacket, "67108864"); err != nil {
panic(err)
}
return sctx
}

Expand Down

0 comments on commit 593fb7d

Please sign in to comment.