Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: handle max_allowed_packet warnings for to_base64 function. #7266

Merged
merged 17 commits into from
Aug 15, 2018
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ func (b *builtinASCIISig) Clone() builtinFunc {
return newSig
}

// eval evals a builtinASCIISig.
// evalInt evals a builtinASCIISig.
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_ascii
func (b *builtinASCIISig) evalInt(row chunk.Row) (int64, bool, error) {
val, isNull, err := b.args[0].EvalString(b.ctx, row)
Expand Down Expand Up @@ -285,6 +285,7 @@ func (b *builtinConcatSig) Clone() builtinFunc {
return newSig
}

// evalString evals a builtinConcatSig
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_concat
func (b *builtinConcatSig) evalString(row chunk.Row) (d string, isNull bool, err error) {
var s []byte
Expand Down Expand Up @@ -568,7 +569,7 @@ func (b *builtinRepeatSig) Clone() builtinFunc {
return newSig
}

// eval evals a builtinRepeatSig.
// evalString evals a builtinRepeatSig.
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_repeat
func (b *builtinRepeatSig) evalString(row chunk.Row) (d string, isNull bool, err error) {
str, isNull, err := b.args[0].EvalString(b.ctx, row)
Expand Down Expand Up @@ -1515,6 +1516,7 @@ type trimFunctionClass struct {
baseFunctionClass
}

// getFunction sets trim built-in function signature.
// The syntax of trim in mysql is 'TRIM([{BOTH | LEADING | TRAILING} [remstr] FROM] str), TRIM([remstr FROM] str)',
// but we wil convert it into trim(str), trim(str, remstr) and trim(str, remstr, direction) in AST.
func (c *trimFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
Expand Down Expand Up @@ -2482,8 +2484,8 @@ func (b *builtinOctStringSig) Clone() builtinFunc {
return newSig
}

// // evalString evals OCT(N).
// // See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_oct
// evalString evals OCT(N).
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_oct
func (b *builtinOctStringSig) evalString(row chunk.Row) (string, bool, error) {
val, isNull, err := b.args[0].EvalString(b.ctx, row)
if isNull || err != nil {
Expand Down Expand Up @@ -2999,17 +3001,26 @@ func (c *toBase64FunctionClass) getFunction(ctx sessionctx.Context, args []Expre
}
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString)
bf.tp.Flen = base64NeededEncodedLength(bf.args[0].GetType().Flen)
sig := &builtinToBase64Sig{bf}

valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, errors.Trace(err)
}

sig := &builtinToBase64Sig{bf, maxAllowedPacket}
return sig, nil
}

type builtinToBase64Sig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinToBase64Sig) Clone() builtinFunc {
newSig := &builtinToBase64Sig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand Down Expand Up @@ -3044,6 +3055,10 @@ func (b *builtinToBase64Sig) evalString(row chunk.Row) (d string, isNull bool, e
return "", isNull, errors.Trace(err)
}

if b.tp.Flen*mysql.MaxBytesOfCharacter > int(b.maxAllowedPacket) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not correct to use Flen to decide whether the result of ToBase64 exceeds max_allowed_packet, I think we can learn from MySQL:

String *Item_func_to_base64::val_str_ascii(String *str) {
  String *res = args[0]->val_str(str);
  bool too_long = false;
  uint64 length;
  if (!res || res->length() > (uint)base64_encode_max_arg_length() ||
      (too_long =
           ((length = base64_needed_encoded_length((uint64)res->length())) >
            current_thd->variables.max_allowed_packet)) ||
      tmp_value.alloc((uint)length)) {
    null_value = 1;  // NULL input, too long input, or OOM.
    if (too_long) {
      push_warning_printf(
          current_thd, Sql_condition::SL_WARNING,
          ER_WARN_ALLOWED_PACKET_OVERFLOWED,
          ER_THD(current_thd, ER_WARN_ALLOWED_PACKET_OVERFLOWED), func_name(),
          current_thd->variables.max_allowed_packet);
    }
    return 0;
  }
  base64_encode(res->ptr(), (int)res->length(), (char *)tmp_value.ptr());
  DBUG_ASSERT(length > 0);
  tmp_value.length((uint)length - 1);  // Without trailing '\0'
  null_value = 0;
  return &tmp_value;
}

The above code is taken from: https://github.com/mysql/mysql-server/blob/5.7/sql/item_strfunc.cc#L690

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, thanks, I will correct it

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, @supernan1994 any update?

Copy link
Contributor Author

@supernan1994 supernan1994 Aug 8, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry!!! I had a bussiness trip and was back yesterday, I will take a look at it tonight and tomorrow night @zz-jason

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😄 take your time, it's just a just a friendly Ping

b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("to_base64", b.maxAllowedPacket))
return "", true, nil
}
if b.tp.Flen == -1 || b.tp.Flen > mysql.MaxBlobWidth {
return "", true, nil
}
Expand Down
78 changes: 78 additions & 0 deletions expression/builtin_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1862,6 +1862,84 @@ func (s *testEvaluatorSuite) TestToBase64(c *C) {
c.Assert(err, IsNil)
}

func (s *testEvaluatorSuite) TestToBase64Sig(c *C) {
colTypes := []*types.FieldType{
{Tp: mysql.TypeVarchar},
}

tests := []struct {
args string
expect string
isNil bool
getErr bool
maxAllowPacket uint64
}{
{"abc", "YWJj", false, false, 16},
{"abc", "", true, false, 15},
{
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
"QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVphYmNkZWZnaGlqa2xtbm9wcXJzdHV2d3h5ejAxMjM0\nNTY3ODkrLw==",
false,
false,
356,
},
{
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
"",
true,
false,
355,
},
{
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
"QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVphYmNkZWZnaGlqa2xtbm9wcXJzdHV2d3h5ejAxMjM0\nNTY3ODkrL0FCQ0RFRkdISUpLTE1OT1BRUlNUVVZXWFlaYWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4\neXowMTIzNDU2Nzg5Ky9BQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWmFiY2RlZmdoaWprbG1ub3Bx\ncnN0dXZ3eHl6MDEyMzQ1Njc4OSsv",
false,
false,
1036,
},
{
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
"",
true,
false,
1035,
},
}

args := []Expression{
&Column{Index: 0, RetType: colTypes[0]},
}

warningCount := 0

for _, test := range tests {
resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: base64NeededEncodedLength(len(test.args))}
base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType}
toBase64 := &builtinToBase64Sig{base, test.maxAllowPacket}

input := chunk.NewChunkWithCapacity(colTypes, 1)
input.AppendString(0, test.args)
res, isNull, err := toBase64.evalString(input.GetRow(0))
if test.getErr {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all the added test.getErr is false, maybe this check can be replaced with: c.Assert(err, IsNil)

c.Assert(err, NotNil)
} else {
c.Assert(err, IsNil)
}
if test.isNil {
c.Assert(isNull, IsTrue)
warningCount += 1
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warningCount will add 1 when to_base64 result is nil

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we reset the warning appended to s.ctx.GetSessionVars().StmtCtx and check the exactly warning count and warning content in each test case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

} else {
c.Assert(isNull, IsFalse)
}
c.Assert(res, Equals, test.expect)
}
warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings()
c.Assert(len(warnings), Equals, warningCount)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warningCount is always zero, seems the test of max_allowed_packets is not working

for _, warn := range warnings {
c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, warn.Err), IsTrue)
}
}

func (s *testEvaluatorSuite) TestStringRight(c *C) {
defer testleak.AfterTest(c)()
fc := funcs[ast.Right]
Expand Down