Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: handle max_allowed_packet warnings for from_base64 function. #7409

Merged
33 changes: 32 additions & 1 deletion expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -2959,18 +2959,39 @@ func (c *fromBase64FunctionClass) getFunction(ctx sessionctx.Context, args []Exp
}
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString)
bf.tp.Flen = mysql.MaxBlobWidth

valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, errors.Trace(err)
}

types.SetBinChsClnFlag(bf.tp)
sig := &builtinFromBase64Sig{bf}
sig := &builtinFromBase64Sig{bf, maxAllowedPacket}
return sig, nil
}

// base64NeededDecodedLength return the base64 decoded string length.
func base64NeededDecodedLength(n int) int {
// Returns -1 indicate the result will overflow.
if strconv.IntSize == 64 && n > math.MaxInt64/3 {
return -1
}
if strconv.IntSize == 32 && n > math.MaxInt32/3 {
return -1
}
return n * 3 / 4
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about n/4*3, then the above validation can be removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when n%4 > 4/3 , n/4*3 is not equal to n*3/4.
mysql calculates decode length by n*3/4, I think we'd better use the same formula.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

}

type builtinFromBase64Sig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinFromBase64Sig) Clone() builtinFunc {
newSig := &builtinFromBase64Sig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand All @@ -2981,6 +3002,16 @@ func (b *builtinFromBase64Sig) evalString(row chunk.Row) (string, bool, error) {
if isNull || err != nil {
return "", true, errors.Trace(err)
}

needDecodeLen := base64NeededDecodedLength(len(str))
if needDecodeLen == -1 {
return "", true, nil
}
if needDecodeLen > int(b.maxAllowedPacket) {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("from_base64", b.maxAllowedPacket))
return "", true, nil
}

str = strings.Replace(str, "\t", "", -1)
str = strings.Replace(str, " ", "", -1)
result, err := base64.StdEncoding.DecodeString(str)
Expand Down
56 changes: 56 additions & 0 deletions expression/builtin_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1609,6 +1609,62 @@ func (s *testEvaluatorSuite) TestFromBase64(c *C) {
}
}

func (s *testEvaluatorSuite) TestFromBase64Sig(c *C) {
colTypes := []*types.FieldType{
{Tp: mysql.TypeVarchar},
}

tests := []struct {
args string
expect string
isNil bool
maxAllowPacket uint64
}{
{string("YWJj"), string("abc"), false, 3},
{string("YWJj"), "", true, 2},
{
string("QUJDREVGR0hJSkt\tMTU5PUFFSU1RVVld\nYWVphYmNkZ\rWZnaGlqa2xt bm9wcXJzdHV2d3h5ejAxMjM0NTY3ODkrLw=="),
string("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"),
false,
70,
},
{
string("QUJDREVGR0hJSkt\tMTU5PUFFSU1RVVld\nYWVphYmNkZ\rWZnaGlqa2xt bm9wcXJzdHV2d3h5ejAxMjM0NTY3ODkrLw=="),
"",
true,
69,
},
}

args := []Expression{
&Column{Index: 0, RetType: colTypes[0]},
}

for _, test := range tests {
resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: mysql.MaxBlobWidth}
base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType}
fromBase64 := &builtinFromBase64Sig{base, test.maxAllowPacket}

input := chunk.NewChunkWithCapacity(colTypes, 1)
input.AppendString(0, test.args)
res, isNull, err := fromBase64.evalString(input.GetRow(0))
c.Assert(err, IsNil)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add c.Assert(isNull, Equals, test.isNil)
then we can remove line #1653 and else branch in line #1662

if test.isNil {
c.Assert(isNull, IsTrue)

warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings()
c.Assert(len(warnings), Equals, 1)
lastWarn := warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue)
s.ctx.GetSessionVars().StmtCtx.SetWarnings([]stmtctx.SQLWarn{})

} else {
c.Assert(isNull, IsFalse)
}
c.Assert(res, Equals, test.expect)
}
}

func (s *testEvaluatorSuite) TestInsert(c *C) {
tests := []struct {
args []interface{}
Expand Down