From 8ce2ad169613ec82d6e03a93861ead3a3e5d27f5 Mon Sep 17 00:00:00 2001 From: YangKeao Date: Thu, 9 Nov 2023 18:42:42 +0800 Subject: [PATCH] parse: fix the type of date/time parameters (#48237) close pingcap/tidb#45190 --- pkg/param/BUILD.bazel | 16 + pkg/param/binary_params.go | 275 +++++++++++++++++ pkg/server/BUILD.bazel | 10 +- pkg/server/conn_stmt.go | 10 +- pkg/server/conn_stmt_params.go | 130 ++++++++ pkg/server/conn_stmt_params_test.go | 380 ++++++++++++++++++++++++ pkg/server/extension.go | 16 +- pkg/server/internal/parse/BUILD.bazel | 13 - pkg/server/internal/parse/parse.go | 337 --------------------- pkg/server/internal/parse/parse_test.go | 231 -------------- pkg/session/BUILD.bazel | 1 + pkg/session/session.go | 11 + 12 files changed, 839 insertions(+), 591 deletions(-) create mode 100644 pkg/param/BUILD.bazel create mode 100644 pkg/param/binary_params.go create mode 100644 pkg/server/conn_stmt_params.go create mode 100644 pkg/server/conn_stmt_params_test.go diff --git a/pkg/param/BUILD.bazel b/pkg/param/BUILD.bazel new file mode 100644 index 0000000000000..1b2204b5e7f45 --- /dev/null +++ b/pkg/param/BUILD.bazel @@ -0,0 +1,16 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "param", + srcs = ["binary_params.go"], + importpath = "github.com/pingcap/tidb/pkg/param", + visibility = ["//visibility:public"], + deps = [ + "//pkg/errno", + "//pkg/expression", + "//pkg/parser/mysql", + "//pkg/types", + "//pkg/util/dbterror", + "//pkg/util/hack", + ], +) diff --git a/pkg/param/binary_params.go b/pkg/param/binary_params.go new file mode 100644 index 0000000000000..ce016f5016918 --- /dev/null +++ b/pkg/param/binary_params.go @@ -0,0 +1,275 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package param + +import ( + "encoding/binary" + "fmt" + "math" + + "github.com/pingcap/tidb/pkg/errno" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/hack" +) + +var errUnknownFieldType = dbterror.ClassServer.NewStd(errno.ErrUnknownFieldType) + +// BinaryParam stores the information decoded from the binary protocol +// It can be further parsed into `expression.Expression` through the `ExecArgs` function in this package +type BinaryParam struct { + Tp byte + IsUnsigned bool + IsNull bool + Val []byte +} + +// ExecArgs parse execute arguments to datum slice. +func ExecArgs(typectx types.Context, binaryParams []BinaryParam) (params []expression.Expression, err error) { + var ( + tmp interface{} + ) + + params = make([]expression.Expression, len(binaryParams)) + args := make([]types.Datum, len(binaryParams)) + for i := 0; i < len(args); i++ { + tp := binaryParams[i].Tp + isUnsigned := binaryParams[i].IsUnsigned + + switch tp { + case mysql.TypeNull: + var nilDatum types.Datum + nilDatum.SetNull() + args[i] = nilDatum + continue + + case mysql.TypeTiny: + if isUnsigned { + args[i] = types.NewUintDatum(uint64(binaryParams[i].Val[0])) + } else { + args[i] = types.NewIntDatum(int64(int8(binaryParams[i].Val[0]))) + } + continue + + case mysql.TypeShort, mysql.TypeYear: + valU16 := binary.LittleEndian.Uint16(binaryParams[i].Val) + if isUnsigned { + args[i] = types.NewUintDatum(uint64(valU16)) + } else { + args[i] = types.NewIntDatum(int64(int16(valU16))) + } + continue + + case mysql.TypeInt24, mysql.TypeLong: + valU32 := binary.LittleEndian.Uint32(binaryParams[i].Val) + if isUnsigned { + args[i] = types.NewUintDatum(uint64(valU32)) + } else { + args[i] = types.NewIntDatum(int64(int32(valU32))) + } + continue + + case mysql.TypeLonglong: + valU64 := binary.LittleEndian.Uint64(binaryParams[i].Val) + if isUnsigned { + args[i] = types.NewUintDatum(valU64) + } else { + args[i] = types.NewIntDatum(int64(valU64)) + } + continue + + case mysql.TypeFloat: + args[i] = types.NewFloat32Datum(math.Float32frombits(binary.LittleEndian.Uint32(binaryParams[i].Val))) + continue + + case mysql.TypeDouble: + args[i] = types.NewFloat64Datum(math.Float64frombits(binary.LittleEndian.Uint64(binaryParams[i].Val))) + continue + + case mysql.TypeDate, mysql.TypeTimestamp, mysql.TypeDatetime: + switch len(binaryParams[i].Val) { + case 0: + tmp = types.ZeroDatetimeStr + case 4: + _, tmp = binaryDate(0, binaryParams[i].Val) + case 7: + _, tmp = binaryDateTime(0, binaryParams[i].Val) + case 11: + _, tmp = binaryTimestamp(0, binaryParams[i].Val) + case 13: + _, tmp = binaryTimestampWithTZ(0, binaryParams[i].Val) + default: + err = mysql.ErrMalformPacket + return + } + // TODO: generate the time datum directly + var parseTime func(types.Context, string) (types.Time, error) + switch tp { + case mysql.TypeDate: + parseTime = types.ParseDate + case mysql.TypeDatetime: + parseTime = types.ParseDatetime + case mysql.TypeTimestamp: + // To be compatible with MySQL, even the type of parameter is + // TypeTimestamp, the return type should also be `Datetime`. + parseTime = types.ParseDatetime + } + var time types.Time + time, err = parseTime(typectx, tmp.(string)) + err = typectx.HandleTruncate(err) + if err != nil { + return + } + args[i] = types.NewDatum(time) + continue + + case mysql.TypeDuration: + switch len(binaryParams[i].Val) { + case 0: + tmp = "0" + case 8: + isNegative := binaryParams[i].Val[0] + if isNegative > 1 { + err = mysql.ErrMalformPacket + return + } + _, tmp = binaryDuration(1, binaryParams[i].Val, isNegative) + case 12: + isNegative := binaryParams[i].Val[0] + if isNegative > 1 { + err = mysql.ErrMalformPacket + return + } + _, tmp = binaryDurationWithMS(1, binaryParams[i].Val, isNegative) + default: + err = mysql.ErrMalformPacket + return + } + // TODO: generate the duration datum directly + var dur types.Duration + dur, _, err = types.ParseDuration(typectx, tmp.(string), types.MaxFsp) + err = typectx.HandleTruncate(err) + if err != nil { + return + } + args[i] = types.NewDatum(dur) + continue + case mysql.TypeNewDecimal: + if binaryParams[i].IsNull { + args[i] = types.NewDecimalDatum(nil) + } else { + var dec types.MyDecimal + err = typectx.HandleTruncate(dec.FromString(binaryParams[i].Val)) + if err != nil { + return nil, err + } + args[i] = types.NewDecimalDatum(&dec) + } + continue + case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: + if binaryParams[i].IsNull { + args[i] = types.NewBytesDatum(nil) + } else { + args[i] = types.NewBytesDatum(binaryParams[i].Val) + } + continue + case mysql.TypeUnspecified, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString, + mysql.TypeEnum, mysql.TypeSet, mysql.TypeGeometry, mysql.TypeBit: + if !binaryParams[i].IsNull { + tmp = string(hack.String(binaryParams[i].Val)) + } else { + tmp = nil + } + args[i] = types.NewDatum(tmp) + continue + default: + err = errUnknownFieldType.GenWithStack("stmt unknown field type %d", tp) + return + } + } + + for i := range params { + ft := new(types.FieldType) + types.InferParamTypeFromUnderlyingValue(args[i].GetValue(), ft) + params[i] = &expression.Constant{Value: args[i], RetType: ft} + } + return +} + +func binaryDate(pos int, paramValues []byte) (int, string) { + year := binary.LittleEndian.Uint16(paramValues[pos : pos+2]) + pos += 2 + month := paramValues[pos] + pos++ + day := paramValues[pos] + pos++ + return pos, fmt.Sprintf("%04d-%02d-%02d", year, month, day) +} + +func binaryDateTime(pos int, paramValues []byte) (int, string) { + pos, date := binaryDate(pos, paramValues) + hour := paramValues[pos] + pos++ + minute := paramValues[pos] + pos++ + second := paramValues[pos] + pos++ + return pos, fmt.Sprintf("%s %02d:%02d:%02d", date, hour, minute, second) +} + +func binaryTimestamp(pos int, paramValues []byte) (int, string) { + pos, dateTime := binaryDateTime(pos, paramValues) + microSecond := binary.LittleEndian.Uint32(paramValues[pos : pos+4]) + pos += 4 + return pos, fmt.Sprintf("%s.%06d", dateTime, microSecond) +} + +func binaryTimestampWithTZ(pos int, paramValues []byte) (int, string) { + pos, timestamp := binaryTimestamp(pos, paramValues) + tzShiftInMin := int16(binary.LittleEndian.Uint16(paramValues[pos : pos+2])) + tzShiftHour := tzShiftInMin / 60 + tzShiftAbsMin := tzShiftInMin % 60 + if tzShiftAbsMin < 0 { + tzShiftAbsMin = -tzShiftAbsMin + } + pos += 2 + return pos, fmt.Sprintf("%s%+02d:%02d", timestamp, tzShiftHour, tzShiftAbsMin) +} + +func binaryDuration(pos int, paramValues []byte, isNegative uint8) (int, string) { + sign := "" + if isNegative == 1 { + sign = "-" + } + days := binary.LittleEndian.Uint32(paramValues[pos : pos+4]) + pos += 4 + hours := paramValues[pos] + pos++ + minutes := paramValues[pos] + pos++ + seconds := paramValues[pos] + pos++ + return pos, fmt.Sprintf("%s%d %02d:%02d:%02d", sign, days, hours, minutes, seconds) +} + +func binaryDurationWithMS(pos int, paramValues []byte, + isNegative uint8) (int, string) { + pos, dur := binaryDuration(pos, paramValues, isNegative) + microSecond := binary.LittleEndian.Uint32(paramValues[pos : pos+4]) + pos += 4 + return pos, fmt.Sprintf("%s.%06d", dur, microSecond) +} diff --git a/pkg/server/BUILD.bazel b/pkg/server/BUILD.bazel index f3a551b24bd1b..fbe3f826b0807 100644 --- a/pkg/server/BUILD.bazel +++ b/pkg/server/BUILD.bazel @@ -5,6 +5,7 @@ go_library( srcs = [ "conn.go", "conn_stmt.go", + "conn_stmt_params.go", "driver.go", "driver_tidb.go", "extension.go", @@ -32,6 +33,7 @@ go_library( "//pkg/infoschema", "//pkg/kv", "//pkg/metrics", + "//pkg/param", "//pkg/parser", "//pkg/parser/ast", "//pkg/parser/auth", @@ -76,6 +78,7 @@ go_library( "//pkg/util/arena", "//pkg/util/chunk", "//pkg/util/cpuprofile", + "//pkg/util/dbterror", "//pkg/util/dbterror/exeerrors", "//pkg/util/execdetails", "//pkg/util/fastrand", @@ -125,6 +128,7 @@ go_test( name = "server_test", timeout = "short", srcs = [ + "conn_stmt_params_test.go", "conn_stmt_test.go", "conn_test.go", "driver_tidb_test.go", @@ -138,20 +142,23 @@ go_test( data = glob(["testdata/**"]), embed = [":server"], flaky = True, - shard_count = 48, + shard_count = 50, deps = [ "//pkg/config", "//pkg/domain", "//pkg/domain/infosync", + "//pkg/expression", "//pkg/extension", "//pkg/keyspace", "//pkg/kv", "//pkg/metrics", + "//pkg/param", "//pkg/parser/ast", "//pkg/parser/auth", "//pkg/parser/charset", "//pkg/parser/model", "//pkg/parser/mysql", + "//pkg/parser/terror", "//pkg/server/internal", "//pkg/server/internal/column", "//pkg/server/internal/handshake", @@ -178,6 +185,7 @@ go_test( "//pkg/util/syncutil", "//pkg/util/topsql/state", "@com_github_docker_go_units//:go-units", + "@com_github_pingcap_errors//:errors", "@com_github_pingcap_failpoint//:failpoint", "@com_github_pingcap_kvproto//pkg/metapb", "@com_github_stretchr_testify//require", diff --git a/pkg/server/conn_stmt.go b/pkg/server/conn_stmt.go index 93e7e1511f4d1..cfb6a20da506f 100644 --- a/pkg/server/conn_stmt.go +++ b/pkg/server/conn_stmt.go @@ -44,8 +44,8 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/param" "github.com/pingcap/tidb/pkg/parser" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/charset" @@ -180,7 +180,7 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e ) cc.initInputEncoder(ctx) numParams := stmt.NumParams() - args := make([]expression.Expression, numParams) + args := make([]param.BinaryParam, numParams) if numParams > 0 { nullBitmapLen := (numParams + 7) >> 3 if len(data) < (pos + nullBitmapLen + 1) { @@ -206,7 +206,7 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e paramValues = data[pos+1:] } - err = parse.ExecArgs(cc.ctx.GetSessionVars().StmtCtx, args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues, cc.inputDecoder) + err = parseBinaryParams(args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues, cc.inputDecoder) // This `.Reset` resets the arguments, so it's fine to just ignore the error (and the it'll be reset again in the following routine) errReset := stmt.Reset() if errReset != nil { @@ -227,7 +227,7 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e return err } -func (cc *clientConn) executePlanCacheStmt(ctx context.Context, stmt interface{}, args []expression.Expression, useCursor bool) (err error) { +func (cc *clientConn) executePlanCacheStmt(ctx context.Context, stmt interface{}, args []param.BinaryParam, useCursor bool) (err error) { ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) ctx = context.WithValue(ctx, util.ExecDetailsKey, &util.ExecDetails{}) retryable, err := cc.executePreparedStmtAndWriteResult(ctx, stmt.(PreparedStatement), args, useCursor) @@ -262,7 +262,7 @@ func (cc *clientConn) executePlanCacheStmt(ctx context.Context, stmt interface{} // The first return value indicates whether the call of executePreparedStmtAndWriteResult has no side effect and can be retried. // Currently the first return value is used to fallback to TiKV when TiFlash is down. -func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stmt PreparedStatement, args []expression.Expression, useCursor bool) (bool, error) { +func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stmt PreparedStatement, args []param.BinaryParam, useCursor bool) (bool, error) { vars := (&cc.ctx).GetSessionVars() prepStmt, err := vars.GetPreparedStmtByID(uint32(stmt.ID())) if err != nil { diff --git a/pkg/server/conn_stmt_params.go b/pkg/server/conn_stmt_params.go new file mode 100644 index 0000000000000..5b625c683faac --- /dev/null +++ b/pkg/server/conn_stmt_params.go @@ -0,0 +1,130 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "github.com/pingcap/tidb/pkg/errno" + "github.com/pingcap/tidb/pkg/param" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/mysql" + util2 "github.com/pingcap/tidb/pkg/server/internal/util" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/dbterror" +) + +var errUnknownFieldType = dbterror.ClassServer.NewStd(errno.ErrUnknownFieldType) + +// parseBinaryParams decodes the binary params according to the protocol +func parseBinaryParams(params []param.BinaryParam, boundParams [][]byte, nullBitmap, paramTypes, paramValues []byte, enc *util2.InputDecoder) (err error) { + pos := 0 + if enc == nil { + enc = util2.NewInputDecoder(charset.CharsetUTF8) + } + + for i := 0; i < len(params); i++ { + // if params had received via ComStmtSendLongData, use them directly. + // ref https://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html + // see clientConn#handleStmtSendLongData + if boundParams[i] != nil { + params[i] = param.BinaryParam{ + Tp: mysql.TypeBlob, + Val: enc.DecodeInput(boundParams[i]), + } + continue + } + + // check nullBitMap to determine the NULL arguments. + // ref https://dev.mysql.com/doc/internals/en/com-stmt-execute.html + // notice: some client(e.g. mariadb) will set nullBitMap even if data had be sent via ComStmtSendLongData, + // so this check need place after boundParam's check. + if nullBitmap[i>>3]&(1<<(uint(i)%8)) > 0 { + var nilDatum types.Datum + nilDatum.SetNull() + params[i] = param.BinaryParam{ + Tp: mysql.TypeNull, + } + continue + } + + if (i<<1)+1 >= len(paramTypes) { + return mysql.ErrMalformPacket + } + + tp := paramTypes[i<<1] + isUnsigned := (paramTypes[(i<<1)+1] & 0x80) > 0 + isNull := false + + decodeWithDecoder := false + + var length uint64 + switch tp { + case mysql.TypeNull: + length = 0 + isNull = true + case mysql.TypeTiny: + length = 1 + case mysql.TypeShort, mysql.TypeYear: + length = 2 + case mysql.TypeInt24, mysql.TypeLong, mysql.TypeFloat: + length = 4 + case mysql.TypeLonglong, mysql.TypeDouble: + length = 8 + case mysql.TypeDate, mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDuration: + if len(paramValues) < (pos + 1) { + err = mysql.ErrMalformPacket + return + } + length = uint64(paramValues[pos]) + pos++ + case mysql.TypeNewDecimal, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: + if len(paramValues) < (pos + 1) { + err = mysql.ErrMalformPacket + return + } + var n int + length, isNull, n = util2.ParseLengthEncodedInt(paramValues[pos:]) + pos += n + case mysql.TypeUnspecified, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString, + mysql.TypeEnum, mysql.TypeSet, mysql.TypeGeometry, mysql.TypeBit: + if len(paramValues) < (pos + 1) { + err = mysql.ErrMalformPacket + return + } + var n int + length, isNull, n = util2.ParseLengthEncodedInt(paramValues[pos:]) + pos += n + decodeWithDecoder = true + default: + err = errUnknownFieldType.GenWithStack("stmt unknown field type %d", tp) + return + } + + if len(paramValues) < (pos + int(length)) { + err = mysql.ErrMalformPacket + return + } + params[i] = param.BinaryParam{ + Tp: tp, + IsUnsigned: isUnsigned, + IsNull: isNull, + Val: paramValues[pos : pos+int(length)], + } + if decodeWithDecoder { + params[i].Val = enc.DecodeInput(params[i].Val) + } + pos += int(length) + } + return +} diff --git a/pkg/server/conn_stmt_params_test.go b/pkg/server/conn_stmt_params_test.go new file mode 100644 index 0000000000000..beaa275692e52 --- /dev/null +++ b/pkg/server/conn_stmt_params_test.go @@ -0,0 +1,380 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "context" + "encoding/binary" + "testing" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/param" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/server/internal/column" + "github.com/pingcap/tidb/pkg/server/internal/util" + "github.com/pingcap/tidb/pkg/testkit" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/stretchr/testify/require" +) + +// decodeAndParse uses the `parseBinaryParams` and `parse.ExecArgs` to parse the params passed through binary protocol +// It helps to test the integration of these two functions +func decodeAndParse(typectx types.Context, args []expression.Expression, boundParams [][]byte, + nullBitmap, paramTypes, paramValues []byte, enc *util.InputDecoder) (err error) { + binParams := make([]param.BinaryParam, len(args)) + err = parseBinaryParams(binParams, boundParams, nullBitmap, paramTypes, paramValues, enc) + if err != nil { + return err + } + + parsedArgs, err := param.ExecArgs(typectx, binParams) + if err != nil { + return err + } + + for i := 0; i < len(args); i++ { + args[i] = parsedArgs[i] + } + return +} + +func TestParseExecArgs(t *testing.T) { + type args struct { + args []expression.Expression + boundParams [][]byte + nullBitmap []byte + paramTypes []byte + paramValues []byte + } + tests := []struct { + args args + err error + warn error + expect interface{} + }{ + // Tests for int overflow + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{1, 0}, + []byte{0xff}, + }, + nil, + nil, + int64(-1), + }, + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{2, 0}, + []byte{0xff, 0xff}, + }, + nil, + nil, + int64(-1), + }, + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{3, 0}, + []byte{0xff, 0xff, 0xff, 0xff}, + }, + nil, + nil, + int64(-1), + }, + // Tests for date/datetime/timestamp + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{12, 0}, + []byte{0x0b, 0xda, 0x07, 0x0a, 0x11, 0x13, 0x1b, 0x1e, 0x01, 0x00, 0x00, 0x00}, + }, + nil, + nil, + types.NewTime(types.FromDate(2010, 10, 17, 19, 27, 30, 1), mysql.TypeDatetime, types.MaxFsp), + }, + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{10, 0}, + []byte{0x04, 0xda, 0x07, 0x0a, 0x11}, + }, + nil, + nil, + types.NewTime(types.FromDate(2010, 10, 17, 0, 0, 0, 0), mysql.TypeDate, types.DefaultFsp), + }, + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{7, 0}, + []byte{0x0b, 0xda, 0x07, 0x0a, 0x11, 0x13, 0x1b, 0x1e, 0x01, 0x00, 0x00, 0x00}, + }, + nil, + nil, + types.NewTime(types.FromDate(2010, 10, 17, 19, 27, 30, 1), mysql.TypeDatetime, types.MaxFsp), + }, + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{7, 0}, + []byte{0x07, 0xda, 0x07, 0x0a, 0x11, 0x13, 0x1b, 0x1e}, + }, + nil, + nil, + types.NewTime(types.FromDate(2010, 10, 17, 19, 27, 30, 0), mysql.TypeDatetime, types.DefaultFsp), + }, + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{7, 0}, + []byte{0x0d, 0xdb, 0x07, 0x02, 0x03, 0x04, 0x05, 0x06, 0x40, 0xe2, 0x01, 0x00, 0xf2, 0x02}, + }, + nil, + nil, + types.NewTime(types.FromDate(2011, 02, 02, 15, 31, 06, 123456), mysql.TypeDatetime, types.MaxFsp), + }, + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{7, 0}, + []byte{0x0d, 0xdb, 0x07, 0x02, 0x03, 0x04, 0x05, 0x06, 0x40, 0xe2, 0x01, 0x00, 0x0e, 0xfd}, + }, + nil, + nil, + types.NewTime(types.FromDate(2011, 02, 03, 16, 39, 06, 123456), mysql.TypeDatetime, types.MaxFsp), + }, + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{7, 0}, + []byte{0x00}, + }, + nil, + nil, + types.NewTime(types.ZeroCoreTime, mysql.TypeDatetime, types.DefaultFsp), + }, + // Tests for time + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{11, 0}, + []byte{0x0c, 0x01, 0x78, 0x00, 0x00, 0x00, 0x13, 0x1b, 0x1e, 0x01, 0x00, 0x00, 0x00}, + }, + nil, + types.ErrTruncatedWrongVal, + types.Duration{Duration: types.MinTime, Fsp: types.MaxFsp}, + }, + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{11, 0}, + []byte{0x08, 0x01, 0x78, 0x00, 0x00, 0x00, 0x13, 0x1b, 0x1e}, + }, + nil, + types.ErrTruncatedWrongVal, + types.Duration{Duration: types.MinTime, Fsp: types.MaxFsp}, + }, + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{11, 0}, + []byte{0x00}, + }, + nil, + nil, + types.Duration{Duration: time.Duration(0), Fsp: types.MaxFsp}, + }, + // For error test + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{7, 0}, + []byte{10}, + }, + mysql.ErrMalformPacket, + nil, + nil, + }, + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{11, 0}, + []byte{10}, + }, + mysql.ErrMalformPacket, + nil, + nil, + }, + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{11, 0}, + []byte{8, 2}, + }, + mysql.ErrMalformPacket, + nil, + nil, + }, + } + for _, tt := range tests { + var warn error + typectx := types.NewContext(types.DefaultStmtFlags.WithTruncateAsWarning(true), time.UTC, func(err error) { + warn = err + }) + err := decodeAndParse(typectx, tt.args.args, tt.args.boundParams, tt.args.nullBitmap, tt.args.paramTypes, tt.args.paramValues, nil) + require.Truef(t, terror.ErrorEqual(err, tt.err), "err %v", err) + require.Truef(t, terror.ErrorEqual(warn, tt.warn), "warn %v", warn) + if err == nil { + require.Equal(t, tt.expect, tt.args.args[0].(*expression.Constant).Value.GetValue()) + } + } +} + +func TestParseExecArgsAndEncode(t *testing.T) { + dt := expression.Args2Expressions4Test(1) + err := decodeAndParse(types.DefaultStmtNoWarningContext, + dt, + [][]byte{nil}, + []byte{0x0}, + []byte{mysql.TypeVarchar, 0}, + []byte{4, 178, 226, 202, 212}, + util.NewInputDecoder("gbk")) + require.NoError(t, err) + require.Equal(t, "测试", dt[0].(*expression.Constant).Value.GetValue()) + + err = decodeAndParse(types.DefaultStmtNoWarningContext, + dt, + [][]byte{{178, 226, 202, 212}}, + []byte{0x0}, + []byte{mysql.TypeString, 0}, + []byte{}, + util.NewInputDecoder("gbk")) + require.NoError(t, err) + require.Equal(t, "测试", dt[0].(*expression.Constant).Value.GetString()) +} + +func buildDatetimeParam(year uint16, month uint8, day uint8, hour uint8, min uint8, sec uint8, msec uint32) []byte { + endian := binary.LittleEndian + + result := []byte{mysql.TypeDatetime, 0x0, 11} + result = endian.AppendUint16(result, year) + result = append(result, month) + result = append(result, day) + result = append(result, hour) + result = append(result, min) + result = append(result, sec) + result = endian.AppendUint32(result, msec) + return result +} + +func expectedDatetimeExecuteResult(t *testing.T, c *mockConn, time types.Time, warnCount int) []byte { + return getExpectOutput(t, c, func(conn *clientConn) { + var err error + + cols := []*column.Info{{ + Name: "t", + Table: "", + Type: mysql.TypeDatetime, + Charset: uint16(mysql.CharsetNameToID(charset.CharsetBin)), + Flag: uint16(mysql.NotNullFlag | mysql.BinaryFlag), + Decimal: 6, + ColumnLength: 26, + }} + require.NoError(t, conn.writeColumnInfo(cols)) + + chk := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeDatetime)}, 1) + chk.AppendTime(0, time) + data := make([]byte, 4) + data, err = column.DumpBinaryRow(data, cols, chk.GetRow(0), conn.rsEncoder) + require.NoError(t, err) + require.NoError(t, conn.writePacket(data)) + + for i := 0; i < warnCount; i++ { + conn.ctx.GetSessionVars().StmtCtx.AppendWarning(errors.New("any error")) + } + require.NoError(t, conn.writeEOF(context.Background(), mysql.ServerStatusAutocommit)) + }) +} + +func TestDateTimeTypes(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + srv := CreateMockServer(t, store) + srv.SetDomain(dom) + defer srv.Close() + + appendUint32 := binary.LittleEndian.AppendUint32 + ctx := context.Background() + c := CreateMockConn(t, srv).(*mockConn) + c.capability = mysql.ClientProtocol41 | mysql.ClientDeprecateEOF + + tk := testkit.NewTestKitWithSession(t, store, c.Context().Session) + tk.MustExec("use test") + stmt, _, _, err := c.Context().Prepare("select ? as t") + require.NoError(t, err) + + expectedTimeDatum, err := types.ParseDatetime(types.DefaultStmtNoWarningContext, "2023-11-09 14:23:45.000100") + require.NoError(t, err) + expected := expectedDatetimeExecuteResult(t, c, expectedTimeDatum, 1) + + // execute the statement with datetime parameter + req := append( + appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())), + 0x0, 0x1, 0x0, 0x0, 0x0, + 0x0, 0x1, + ) + req = append(req, buildDatetimeParam(2023, 11, 9, 14, 23, 45, 100)...) + out := c.GetOutput() + require.NoError(t, c.Dispatch(ctx, req)) + + require.Equal(t, expected, out.Bytes()) +} diff --git a/pkg/server/extension.go b/pkg/server/extension.go index 2ffc23047613e..407c4fa979e8b 100644 --- a/pkg/server/extension.go +++ b/pkg/server/extension.go @@ -17,8 +17,8 @@ package server import ( "fmt" - "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/extension" + "github.com/pingcap/tidb/pkg/param" "github.com/pingcap/tidb/pkg/parser" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/auth" @@ -57,7 +57,7 @@ func (cc *clientConn) onExtensionConnEvent(tp extension.ConnEventTp, err error) cc.extensions.OnConnectionEvent(tp, info) } -func (cc *clientConn) onExtensionStmtEnd(node interface{}, stmtCtxValid bool, err error, args ...expression.Expression) { +func (cc *clientConn) onExtensionStmtEnd(node interface{}, stmtCtxValid bool, err error, args ...param.BinaryParam) { if !cc.extensions.HasStmtEventListeners() { return } @@ -85,9 +85,17 @@ func (cc *clientConn) onExtensionStmtEnd(node interface{}, stmtCtxValid bool, er case PreparedStatement: info.executeStmtID = uint32(stmt.ID()) prepared, _ := sessVars.GetPreparedStmtByID(info.executeStmtID) + + // TODO: the `BinaryParam` is parsed two times: one in the `Execute` method and one here. It would be better to + // eliminate one of them by storing the parsed result. + typectx := ctx.GetSessionVars().StmtCtx.TypeCtx() + typectx = types.NewContext(typectx.Flags(), typectx.Location(), func(_ error) { + // ignore all warnings + }) + params, _ := param.ExecArgs(typectx, args) info.executeStmt = &ast.ExecuteStmt{ PrepStmt: prepared, - BinaryArgs: args, + BinaryArgs: params, } info.stmtNode = info.executeStmt case ast.StmtNode: @@ -115,7 +123,7 @@ func (cc *clientConn) onExtensionSQLParseFailed(sql string, err error) { }) } -func (cc *clientConn) onExtensionBinaryExecuteEnd(prep PreparedStatement, args []expression.Expression, stmtCtxValid bool, err error) { +func (cc *clientConn) onExtensionBinaryExecuteEnd(prep PreparedStatement, args []param.BinaryParam, stmtCtxValid bool, err error) { cc.onExtensionStmtEnd(prep, stmtCtxValid, err, args...) } diff --git a/pkg/server/internal/parse/BUILD.bazel b/pkg/server/internal/parse/BUILD.bazel index 5dd9a9d717262..dfab3da1f7923 100644 --- a/pkg/server/internal/parse/BUILD.bazel +++ b/pkg/server/internal/parse/BUILD.bazel @@ -6,16 +6,9 @@ go_library( importpath = "github.com/pingcap/tidb/pkg/server/internal/parse", visibility = ["//pkg/server:__subpackages__"], deps = [ - "//pkg/errno", - "//pkg/expression", - "//pkg/parser/charset", "//pkg/parser/mysql", "//pkg/server/internal/handshake", "//pkg/server/internal/util", - "//pkg/sessionctx/stmtctx", - "//pkg/types", - "//pkg/util/dbterror", - "//pkg/util/hack", "//pkg/util/logutil", "@com_github_klauspost_compress//zstd", "@org_uber_go_zap//:zap", @@ -31,15 +24,9 @@ go_test( ], embed = [":parse"], flaky = True, - shard_count = 4, deps = [ - "//pkg/expression", "//pkg/parser/mysql", - "//pkg/parser/terror", "//pkg/server/internal/handshake", - "//pkg/server/internal/util", - "//pkg/sessionctx/stmtctx", - "//pkg/types", "@com_github_stretchr_testify//require", ], ) diff --git a/pkg/server/internal/parse/parse.go b/pkg/server/internal/parse/parse.go index e55ca68eb6656..de7571d2c287c 100644 --- a/pkg/server/internal/parse/parse.go +++ b/pkg/server/internal/parse/parse.go @@ -18,357 +18,20 @@ import ( "bytes" "context" "encoding/binary" - "fmt" - "math" "github.com/klauspost/compress/zstd" - "github.com/pingcap/tidb/pkg/errno" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/parser/charset" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/server/internal/handshake" util2 "github.com/pingcap/tidb/pkg/server/internal/util" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/dbterror" - "github.com/pingcap/tidb/pkg/util/hack" "github.com/pingcap/tidb/pkg/util/logutil" "go.uber.org/zap" ) -var errUnknownFieldType = dbterror.ClassServer.NewStd(errno.ErrUnknownFieldType) - // maxFetchSize constants const ( maxFetchSize = 1024 ) -// ExecArgs parse execute arguments to datum slice. -func ExecArgs(sc *stmtctx.StatementContext, params []expression.Expression, boundParams [][]byte, - nullBitmap, paramTypes, paramValues []byte, enc *util2.InputDecoder) (err error) { - pos := 0 - var ( - tmp interface{} - v []byte - n int - isNull bool - ) - if enc == nil { - enc = util2.NewInputDecoder(charset.CharsetUTF8) - } - - args := make([]types.Datum, len(params)) - for i := 0; i < len(args); i++ { - // if params had received via ComStmtSendLongData, use them directly. - // ref https://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html - // see clientConn#handleStmtSendLongData - if boundParams[i] != nil { - args[i] = types.NewBytesDatum(enc.DecodeInput(boundParams[i])) - continue - } - - // check nullBitMap to determine the NULL arguments. - // ref https://dev.mysql.com/doc/internals/en/com-stmt-execute.html - // notice: some client(e.g. mariadb) will set nullBitMap even if data had be sent via ComStmtSendLongData, - // so this check need place after boundParam's check. - if nullBitmap[i>>3]&(1<<(uint(i)%8)) > 0 { - var nilDatum types.Datum - nilDatum.SetNull() - args[i] = nilDatum - continue - } - - if (i<<1)+1 >= len(paramTypes) { - return mysql.ErrMalformPacket - } - - tp := paramTypes[i<<1] - isUnsigned := (paramTypes[(i<<1)+1] & 0x80) > 0 - - switch tp { - case mysql.TypeNull: - var nilDatum types.Datum - nilDatum.SetNull() - args[i] = nilDatum - continue - - case mysql.TypeTiny: - if len(paramValues) < (pos + 1) { - err = mysql.ErrMalformPacket - return - } - - if isUnsigned { - args[i] = types.NewUintDatum(uint64(paramValues[pos])) - } else { - args[i] = types.NewIntDatum(int64(int8(paramValues[pos]))) - } - - pos++ - continue - - case mysql.TypeShort, mysql.TypeYear: - if len(paramValues) < (pos + 2) { - err = mysql.ErrMalformPacket - return - } - valU16 := binary.LittleEndian.Uint16(paramValues[pos : pos+2]) - if isUnsigned { - args[i] = types.NewUintDatum(uint64(valU16)) - } else { - args[i] = types.NewIntDatum(int64(int16(valU16))) - } - pos += 2 - continue - - case mysql.TypeInt24, mysql.TypeLong: - if len(paramValues) < (pos + 4) { - err = mysql.ErrMalformPacket - return - } - valU32 := binary.LittleEndian.Uint32(paramValues[pos : pos+4]) - if isUnsigned { - args[i] = types.NewUintDatum(uint64(valU32)) - } else { - args[i] = types.NewIntDatum(int64(int32(valU32))) - } - pos += 4 - continue - - case mysql.TypeLonglong: - if len(paramValues) < (pos + 8) { - err = mysql.ErrMalformPacket - return - } - valU64 := binary.LittleEndian.Uint64(paramValues[pos : pos+8]) - if isUnsigned { - args[i] = types.NewUintDatum(valU64) - } else { - args[i] = types.NewIntDatum(int64(valU64)) - } - pos += 8 - continue - - case mysql.TypeFloat: - if len(paramValues) < (pos + 4) { - err = mysql.ErrMalformPacket - return - } - - args[i] = types.NewFloat32Datum(math.Float32frombits(binary.LittleEndian.Uint32(paramValues[pos : pos+4]))) - pos += 4 - continue - - case mysql.TypeDouble: - if len(paramValues) < (pos + 8) { - err = mysql.ErrMalformPacket - return - } - - args[i] = types.NewFloat64Datum(math.Float64frombits(binary.LittleEndian.Uint64(paramValues[pos : pos+8]))) - pos += 8 - continue - - case mysql.TypeDate, mysql.TypeTimestamp, mysql.TypeDatetime: - if len(paramValues) < (pos + 1) { - err = mysql.ErrMalformPacket - return - } - // See https://dev.mysql.com/doc/internals/en/binary-protocol-value.html - // for more details. - length := paramValues[pos] - pos++ - switch length { - case 0: - tmp = types.ZeroDatetimeStr - case 4: - pos, tmp = binaryDate(pos, paramValues) - case 7: - pos, tmp = binaryDateTime(pos, paramValues) - case 11: - pos, tmp = binaryTimestamp(pos, paramValues) - case 13: - pos, tmp = binaryTimestampWithTZ(pos, paramValues) - default: - err = mysql.ErrMalformPacket - return - } - args[i] = types.NewDatum(tmp) // FIXME: After check works!!!!!! - continue - - case mysql.TypeDuration: - if len(paramValues) < (pos + 1) { - err = mysql.ErrMalformPacket - return - } - // See https://dev.mysql.com/doc/internals/en/binary-protocol-value.html - // for more details. - length := paramValues[pos] - pos++ - switch length { - case 0: - tmp = "0" - case 8: - isNegative := paramValues[pos] - if isNegative > 1 { - err = mysql.ErrMalformPacket - return - } - pos++ - pos, tmp = binaryDuration(pos, paramValues, isNegative) - case 12: - isNegative := paramValues[pos] - if isNegative > 1 { - err = mysql.ErrMalformPacket - return - } - pos++ - pos, tmp = binaryDurationWithMS(pos, paramValues, isNegative) - default: - err = mysql.ErrMalformPacket - return - } - args[i] = types.NewDatum(tmp) - continue - case mysql.TypeNewDecimal: - if len(paramValues) < (pos + 1) { - err = mysql.ErrMalformPacket - return - } - - v, isNull, n, err = util2.ParseLengthEncodedBytes(paramValues[pos:]) - pos += n - if err != nil { - return - } - - if isNull { - args[i] = types.NewDecimalDatum(nil) - } else { - var dec types.MyDecimal - err = sc.HandleTruncate(dec.FromString(v)) - if err != nil { - return err - } - args[i] = types.NewDecimalDatum(&dec) - } - continue - case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: - if len(paramValues) < (pos + 1) { - err = mysql.ErrMalformPacket - return - } - v, isNull, n, err = util2.ParseLengthEncodedBytes(paramValues[pos:]) - pos += n - if err != nil { - return - } - - if isNull { - args[i] = types.NewBytesDatum(nil) - } else { - args[i] = types.NewBytesDatum(v) - } - continue - case mysql.TypeUnspecified, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString, - mysql.TypeEnum, mysql.TypeSet, mysql.TypeGeometry, mysql.TypeBit: - if len(paramValues) < (pos + 1) { - err = mysql.ErrMalformPacket - return - } - - v, isNull, n, err = util2.ParseLengthEncodedBytes(paramValues[pos:]) - pos += n - if err != nil { - return - } - - if !isNull { - v = enc.DecodeInput(v) - tmp = string(hack.String(v)) - } else { - tmp = nil - } - args[i] = types.NewDatum(tmp) - continue - default: - err = errUnknownFieldType.GenWithStack("stmt unknown field type %d", tp) - return - } - } - - for i := range params { - ft := new(types.FieldType) - types.InferParamTypeFromUnderlyingValue(args[i].GetValue(), ft) - params[i] = &expression.Constant{Value: args[i], RetType: ft} - } - return -} - -func binaryDate(pos int, paramValues []byte) (int, string) { - year := binary.LittleEndian.Uint16(paramValues[pos : pos+2]) - pos += 2 - month := paramValues[pos] - pos++ - day := paramValues[pos] - pos++ - return pos, fmt.Sprintf("%04d-%02d-%02d", year, month, day) -} - -func binaryDateTime(pos int, paramValues []byte) (int, string) { - pos, date := binaryDate(pos, paramValues) - hour := paramValues[pos] - pos++ - minute := paramValues[pos] - pos++ - second := paramValues[pos] - pos++ - return pos, fmt.Sprintf("%s %02d:%02d:%02d", date, hour, minute, second) -} - -func binaryTimestamp(pos int, paramValues []byte) (int, string) { - pos, dateTime := binaryDateTime(pos, paramValues) - microSecond := binary.LittleEndian.Uint32(paramValues[pos : pos+4]) - pos += 4 - return pos, fmt.Sprintf("%s.%06d", dateTime, microSecond) -} - -func binaryTimestampWithTZ(pos int, paramValues []byte) (int, string) { - pos, timestamp := binaryTimestamp(pos, paramValues) - tzShiftInMin := int16(binary.LittleEndian.Uint16(paramValues[pos : pos+2])) - tzShiftHour := tzShiftInMin / 60 - tzShiftAbsMin := tzShiftInMin % 60 - if tzShiftAbsMin < 0 { - tzShiftAbsMin = -tzShiftAbsMin - } - pos += 2 - return pos, fmt.Sprintf("%s%+02d:%02d", timestamp, tzShiftHour, tzShiftAbsMin) -} - -func binaryDuration(pos int, paramValues []byte, isNegative uint8) (int, string) { - sign := "" - if isNegative == 1 { - sign = "-" - } - days := binary.LittleEndian.Uint32(paramValues[pos : pos+4]) - pos += 4 - hours := paramValues[pos] - pos++ - minutes := paramValues[pos] - pos++ - seconds := paramValues[pos] - pos++ - return pos, fmt.Sprintf("%s%d %02d:%02d:%02d", sign, days, hours, minutes, seconds) -} - -func binaryDurationWithMS(pos int, paramValues []byte, - isNegative uint8) (int, string) { - pos, dur := binaryDuration(pos, paramValues, isNegative) - microSecond := binary.LittleEndian.Uint32(paramValues[pos : pos+4]) - pos += 4 - return pos, fmt.Sprintf("%s.%06d", dur, microSecond) -} - // StmtFetchCmd parse COM_STMT_FETCH command func StmtFetchCmd(data []byte) (stmtID uint32, fetchSize uint32, err error) { if len(data) != 8 { diff --git a/pkg/server/internal/parse/parse_test.go b/pkg/server/internal/parse/parse_test.go index b8345cceca2eb..cc44038cb50a5 100644 --- a/pkg/server/internal/parse/parse_test.go +++ b/pkg/server/internal/parse/parse_test.go @@ -17,241 +17,10 @@ package parse import ( "testing" - "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/server/internal/util" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" - "github.com/pingcap/tidb/pkg/types" "github.com/stretchr/testify/require" ) -func TestParseExecArgs(t *testing.T) { - type args struct { - args []expression.Expression - boundParams [][]byte - nullBitmap []byte - paramTypes []byte - paramValues []byte - } - tests := []struct { - args args - err error - expect interface{} - }{ - // Tests for int overflow - { - args{ - expression.Args2Expressions4Test(1), - [][]byte{nil}, - []byte{0x0}, - []byte{1, 0}, - []byte{0xff}, - }, - nil, - int64(-1), - }, - { - args{ - expression.Args2Expressions4Test(1), - [][]byte{nil}, - []byte{0x0}, - []byte{2, 0}, - []byte{0xff, 0xff}, - }, - nil, - int64(-1), - }, - { - args{ - expression.Args2Expressions4Test(1), - [][]byte{nil}, - []byte{0x0}, - []byte{3, 0}, - []byte{0xff, 0xff, 0xff, 0xff}, - }, - nil, - int64(-1), - }, - // Tests for date/datetime/timestamp - { - args{ - expression.Args2Expressions4Test(1), - [][]byte{nil}, - []byte{0x0}, - []byte{12, 0}, - []byte{0x0b, 0xda, 0x07, 0x0a, 0x11, 0x13, 0x1b, 0x1e, 0x01, 0x00, 0x00, 0x00}, - }, - nil, - "2010-10-17 19:27:30.000001", - }, - { - args{ - expression.Args2Expressions4Test(1), - [][]byte{nil}, - []byte{0x0}, - []byte{10, 0}, - []byte{0x04, 0xda, 0x07, 0x0a, 0x11}, - }, - nil, - "2010-10-17", - }, - { - args{ - expression.Args2Expressions4Test(1), - [][]byte{nil}, - []byte{0x0}, - []byte{7, 0}, - []byte{0x0b, 0xda, 0x07, 0x0a, 0x11, 0x13, 0x1b, 0x1e, 0x01, 0x00, 0x00, 0x00}, - }, - nil, - "2010-10-17 19:27:30.000001", - }, - { - args{ - expression.Args2Expressions4Test(1), - [][]byte{nil}, - []byte{0x0}, - []byte{7, 0}, - []byte{0x07, 0xda, 0x07, 0x0a, 0x11, 0x13, 0x1b, 0x1e}, - }, - nil, - "2010-10-17 19:27:30", - }, - { - args{ - expression.Args2Expressions4Test(1), - [][]byte{nil}, - []byte{0x0}, - []byte{7, 0}, - []byte{0x0d, 0xdb, 0x07, 0x02, 0x03, 0x04, 0x05, 0x06, 0x40, 0xe2, 0x01, 0x00, 0xf2, 0x02}, - }, - nil, - "2011-02-03 04:05:06.123456+12:34", - }, - { - args{ - expression.Args2Expressions4Test(1), - [][]byte{nil}, - []byte{0x0}, - []byte{7, 0}, - []byte{0x0d, 0xdb, 0x07, 0x02, 0x03, 0x04, 0x05, 0x06, 0x40, 0xe2, 0x01, 0x00, 0x0e, 0xfd}, - }, - nil, - "2011-02-03 04:05:06.123456-12:34", - }, - { - args{ - expression.Args2Expressions4Test(1), - [][]byte{nil}, - []byte{0x0}, - []byte{7, 0}, - []byte{0x00}, - }, - nil, - types.ZeroDatetimeStr, - }, - // Tests for time - { - args{ - expression.Args2Expressions4Test(1), - [][]byte{nil}, - []byte{0x0}, - []byte{11, 0}, - []byte{0x0c, 0x01, 0x78, 0x00, 0x00, 0x00, 0x13, 0x1b, 0x1e, 0x01, 0x00, 0x00, 0x00}, - }, - nil, - "-120 19:27:30.000001", - }, - { - args{ - expression.Args2Expressions4Test(1), - [][]byte{nil}, - []byte{0x0}, - []byte{11, 0}, - []byte{0x08, 0x01, 0x78, 0x00, 0x00, 0x00, 0x13, 0x1b, 0x1e}, - }, - nil, - "-120 19:27:30", - }, - { - args{ - expression.Args2Expressions4Test(1), - [][]byte{nil}, - []byte{0x0}, - []byte{11, 0}, - []byte{0x00}, - }, - nil, - "0", - }, - // For error test - { - args{ - expression.Args2Expressions4Test(1), - [][]byte{nil}, - []byte{0x0}, - []byte{7, 0}, - []byte{10}, - }, - mysql.ErrMalformPacket, - nil, - }, - { - args{ - expression.Args2Expressions4Test(1), - [][]byte{nil}, - []byte{0x0}, - []byte{11, 0}, - []byte{10}, - }, - mysql.ErrMalformPacket, - nil, - }, - { - args{ - expression.Args2Expressions4Test(1), - [][]byte{nil}, - []byte{0x0}, - []byte{11, 0}, - []byte{8, 2}, - }, - mysql.ErrMalformPacket, - nil, - }, - } - for _, tt := range tests { - err := ExecArgs(stmtctx.NewStmtCtx(), tt.args.args, tt.args.boundParams, tt.args.nullBitmap, tt.args.paramTypes, tt.args.paramValues, nil) - require.Truef(t, terror.ErrorEqual(err, tt.err), "err %v", err) - if err == nil { - require.Equal(t, tt.expect, tt.args.args[0].(*expression.Constant).Value.GetValue()) - } - } -} - -func TestParseExecArgsAndEncode(t *testing.T) { - dt := expression.Args2Expressions4Test(1) - err := ExecArgs(stmtctx.NewStmtCtx(), - dt, - [][]byte{nil}, - []byte{0x0}, - []byte{mysql.TypeVarchar, 0}, - []byte{4, 178, 226, 202, 212}, - util.NewInputDecoder("gbk")) - require.NoError(t, err) - require.Equal(t, "测试", dt[0].(*expression.Constant).Value.GetValue()) - - err = ExecArgs(stmtctx.NewStmtCtx(), - dt, - [][]byte{{178, 226, 202, 212}}, - []byte{0x0}, - []byte{mysql.TypeString, 0}, - []byte{}, - util.NewInputDecoder("gbk")) - require.NoError(t, err) - require.Equal(t, "测试", dt[0].(*expression.Constant).Value.GetString()) -} - func TestParseStmtFetchCmd(t *testing.T) { tests := []struct { arg []byte diff --git a/pkg/session/BUILD.bazel b/pkg/session/BUILD.bazel index 8cc1ed1a24e21..5a27a9785a923 100644 --- a/pkg/session/BUILD.bazel +++ b/pkg/session/BUILD.bazel @@ -35,6 +35,7 @@ go_library( "//pkg/meta", "//pkg/metrics", "//pkg/owner", + "//pkg/param", "//pkg/parser", "//pkg/parser/ast", "//pkg/parser/auth", diff --git a/pkg/session/session.go b/pkg/session/session.go index 102c0b8f164c7..c50451f56d368 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -55,6 +55,7 @@ import ( "github.com/pingcap/tidb/pkg/meta" "github.com/pingcap/tidb/pkg/metrics" "github.com/pingcap/tidb/pkg/owner" + "github.com/pingcap/tidb/pkg/param" "github.com/pingcap/tidb/pkg/parser" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/auth" @@ -2140,6 +2141,16 @@ func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlex if err := executor.ResetContextOfStmt(s, stmtNode); err != nil { return nil, err } + if execStmt, ok := stmtNode.(*ast.ExecuteStmt); ok { + if binParam, ok := execStmt.BinaryArgs.([]param.BinaryParam); ok { + args, err := param.ExecArgs(s.GetSessionVars().StmtCtx.TypeCtx(), binParam) + if err != nil { + return nil, err + } + execStmt.BinaryArgs = args + } + } + normalizedSQL, digest := s.sessionVars.StmtCtx.SQLDigest() cmdByte := byte(atomic.LoadUint32(&s.GetSessionVars().CommandValue)) if topsqlstate.TopSQLEnabled() {