diff --git a/expression/builtin_string.go b/expression/builtin_string.go index 4998aa24885c8..100e78a2ac6d7 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -232,7 +232,7 @@ func (b *builtinASCIISig) Clone() builtinFunc { return newSig } -// eval evals a builtinASCIISig. +// evalInt evals a builtinASCIISig. // See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_ascii func (b *builtinASCIISig) evalInt(row chunk.Row) (int64, bool, error) { val, isNull, err := b.args[0].EvalString(b.ctx, row) @@ -285,6 +285,7 @@ func (b *builtinConcatSig) Clone() builtinFunc { return newSig } +// evalString evals a builtinConcatSig // See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_concat func (b *builtinConcatSig) evalString(row chunk.Row) (d string, isNull bool, err error) { var s []byte @@ -568,7 +569,7 @@ func (b *builtinRepeatSig) Clone() builtinFunc { return newSig } -// eval evals a builtinRepeatSig. +// evalString evals a builtinRepeatSig. // See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_repeat func (b *builtinRepeatSig) evalString(row chunk.Row) (d string, isNull bool, err error) { str, isNull, err := b.args[0].EvalString(b.ctx, row) @@ -1515,6 +1516,7 @@ type trimFunctionClass struct { baseFunctionClass } +// getFunction sets trim built-in function signature. // The syntax of trim in mysql is 'TRIM([{BOTH | LEADING | TRAILING} [remstr] FROM] str), TRIM([remstr FROM] str)', // but we wil convert it into trim(str), trim(str, remstr) and trim(str, remstr, direction) in AST. func (c *trimFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { @@ -2482,8 +2484,8 @@ func (b *builtinOctStringSig) Clone() builtinFunc { return newSig } -// // evalString evals OCT(N). -// // See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_oct +// evalString evals OCT(N). +// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_oct func (b *builtinOctStringSig) evalString(row chunk.Row) (string, bool, error) { val, isNull, err := b.args[0].EvalString(b.ctx, row) if isNull || err != nil { @@ -2999,17 +3001,26 @@ func (c *toBase64FunctionClass) getFunction(ctx sessionctx.Context, args []Expre } bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString) bf.tp.Flen = base64NeededEncodedLength(bf.args[0].GetType().Flen) - sig := &builtinToBase64Sig{bf} + + valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket) + maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64) + if err != nil { + return nil, errors.Trace(err) + } + + sig := &builtinToBase64Sig{bf, maxAllowedPacket} return sig, nil } type builtinToBase64Sig struct { baseBuiltinFunc + maxAllowedPacket uint64 } func (b *builtinToBase64Sig) Clone() builtinFunc { newSig := &builtinToBase64Sig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.maxAllowedPacket = b.maxAllowedPacket return newSig } @@ -3043,7 +3054,14 @@ func (b *builtinToBase64Sig) evalString(row chunk.Row) (d string, isNull bool, e if isNull || err != nil { return "", isNull, errors.Trace(err) } - + needEncodeLen := base64NeededEncodedLength(len(str)) + if needEncodeLen == -1 { + return "", true, nil + } + if needEncodeLen > int(b.maxAllowedPacket) { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("to_base64", b.maxAllowedPacket)) + return "", true, nil + } if b.tp.Flen == -1 || b.tp.Flen > mysql.MaxBlobWidth { return "", true, nil } diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index 3d067da53a335..7341f7f1055dc 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -23,6 +23,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/terror" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/charset" @@ -1879,6 +1880,74 @@ func (s *testEvaluatorSuite) TestToBase64(c *C) { c.Assert(err, IsNil) } +func (s *testEvaluatorSuite) TestToBase64Sig(c *C) { + colTypes := []*types.FieldType{ + {Tp: mysql.TypeVarchar}, + } + + tests := []struct { + args string + expect string + isNil bool + maxAllowPacket uint64 + }{ + {"abc", "YWJj", false, 4}, + {"abc", "", true, 3}, + { + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/", + "QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVphYmNkZWZnaGlqa2xtbm9wcXJzdHV2d3h5ejAxMjM0\nNTY3ODkrLw==", + false, + 89, + }, + { + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/", + "", + true, + 88, + }, + { + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/", + "QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVphYmNkZWZnaGlqa2xtbm9wcXJzdHV2d3h5ejAxMjM0\nNTY3ODkrL0FCQ0RFRkdISUpLTE1OT1BRUlNUVVZXWFlaYWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4\neXowMTIzNDU2Nzg5Ky9BQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWmFiY2RlZmdoaWprbG1ub3Bx\ncnN0dXZ3eHl6MDEyMzQ1Njc4OSsv", + false, + 259, + }, + { + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/", + "", + true, + 258, + }, + } + + args := []Expression{ + &Column{Index: 0, RetType: colTypes[0]}, + } + + for _, test := range tests { + resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: base64NeededEncodedLength(len(test.args))} + base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType} + toBase64 := &builtinToBase64Sig{base, test.maxAllowPacket} + + input := chunk.NewChunkWithCapacity(colTypes, 1) + input.AppendString(0, test.args) + res, isNull, err := toBase64.evalString(input.GetRow(0)) + c.Assert(err, IsNil) + if test.isNil { + c.Assert(isNull, IsTrue) + + warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() + c.Assert(len(warnings), Equals, 1) + lastWarn := warnings[len(warnings)-1] + c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue) + s.ctx.GetSessionVars().StmtCtx.SetWarnings([]stmtctx.SQLWarn{}) + + } else { + c.Assert(isNull, IsFalse) + } + c.Assert(res, Equals, test.expect) + } +} + func (s *testEvaluatorSuite) TestStringRight(c *C) { defer testleak.AfterTest(c)() fc := funcs[ast.Right]