Skip to content

Commit

Permalink
server: support decoding prepared string args to character_set_client (
Browse files Browse the repository at this point in the history
  • Loading branch information
tangenta authored Dec 15, 2021
1 parent 22418cd commit 04a9618
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 3 deletions.
10 changes: 10 additions & 0 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ type clientConn struct {
authPlugin string // default authentication plugin
isUnixSocket bool // connection is Unix Socket file
rsEncoder *resultEncoder // rsEncoder is used to encode the string result to different charsets.
inputDecoder *inputDecoder // inputDecoder is used to decode the different charsets of incoming strings to utf-8.
socketCredUID uint32 // UID from the other end of the Unix Socket
// mu is used for cancelling the execution of current transaction.
mu struct {
Expand Down Expand Up @@ -964,6 +965,15 @@ func (cc *clientConn) initResultEncoder(ctx context.Context) {
cc.rsEncoder = newResultEncoder(chs)
}

func (cc *clientConn) initInputEncoder(ctx context.Context) {
chs, err := variable.GetSessionOrGlobalSystemVar(cc.ctx.GetSessionVars(), variable.CharacterSetClient)
if err != nil {
chs = ""
logutil.Logger(ctx).Warn("get character_set_client system variable failed", zap.Error(err))
}
cc.inputDecoder = newInputDecoder(chs)
}

// initConnect runs the initConnect SQL statement if it has been specified.
// The semantics are MySQL compatible.
func (cc *clientConn) initConnect(ctx context.Context) error {
Expand Down
12 changes: 10 additions & 2 deletions server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import (

"github.com/pingcap/errors"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser/charset"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
plannercore "github.com/pingcap/tidb/planner/core"
Expand Down Expand Up @@ -167,6 +168,8 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e
paramTypes []byte
paramValues []byte
)
cc.initInputEncoder(ctx)
defer cc.inputDecoder.clean()
numParams := stmt.NumParams()
args := make([]types.Datum, numParams)
if numParams > 0 {
Expand Down Expand Up @@ -194,7 +197,7 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e
paramValues = data[pos+1:]
}

err = parseExecArgs(cc.ctx.GetSessionVars().StmtCtx, args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues)
err = parseExecArgs(cc.ctx.GetSessionVars().StmtCtx, args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues, cc.inputDecoder)
stmt.Reset()
if err != nil {
return errors.Annotate(err, cc.preparedStmt2String(stmtID))
Expand Down Expand Up @@ -310,14 +313,18 @@ func parseStmtFetchCmd(data []byte) (uint32, uint32, error) {
return stmtID, fetchSize, nil
}

func parseExecArgs(sc *stmtctx.StatementContext, args []types.Datum, boundParams [][]byte, nullBitmap, paramTypes, paramValues []byte) (err error) {
func parseExecArgs(sc *stmtctx.StatementContext, args []types.Datum, boundParams [][]byte,
nullBitmap, paramTypes, paramValues []byte, enc *inputDecoder) (err error) {
pos := 0
var (
tmp interface{}
v []byte
n int
isNull bool
)
if enc == nil {
enc = newInputDecoder(charset.CharsetUTF8)
}

for i := 0; i < len(args); i++ {
// if params had received via ComStmtSendLongData, use them directly.
Expand Down Expand Up @@ -543,6 +550,7 @@ func parseExecArgs(sc *stmtctx.StatementContext, args []types.Datum, boundParams
}

if !isNull {
v = enc.decodeInput(v)
tmp = string(hack.String(v))
} else {
tmp = nil
Expand Down
15 changes: 14 additions & 1 deletion server/conn_stmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,25 @@ func TestParseExecArgs(t *testing.T) {
},
}
for _, tt := range tests {
err := parseExecArgs(&stmtctx.StatementContext{}, tt.args.args, tt.args.boundParams, tt.args.nullBitmap, tt.args.paramTypes, tt.args.paramValues)
err := parseExecArgs(&stmtctx.StatementContext{}, tt.args.args, tt.args.boundParams, tt.args.nullBitmap, tt.args.paramTypes, tt.args.paramValues, nil)
require.Truef(t, terror.ErrorEqual(err, tt.err), "err %v", err)
require.Equal(t, tt.expect, tt.args.args[0].GetValue())
}
}

func TestParseExecArgsAndEncode(t *testing.T) {
dt := make([]types.Datum, 1)
err := parseExecArgs(&stmtctx.StatementContext{},
dt,
[][]byte{nil},
[]byte{0x0},
[]byte{mysql.TypeVarchar, 0},
[]byte{4, 178, 226, 202, 212},
newInputDecoder("gbk"))
require.NoError(t, err)
require.Equal(t, "测试", dt[0].GetValue())
}

func TestParseStmtFetchCmd(t *testing.T) {
tests := []struct {
arg []byte
Expand Down
26 changes: 26 additions & 0 deletions server/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,32 @@ func dumpBinaryRow(buffer []byte, columns []*ColumnInfo, row chunk.Row, d *resul
return buffer, nil
}

type inputDecoder struct {
encoding *charset.Encoding

buffer []byte
}

func newInputDecoder(chs string) *inputDecoder {
return &inputDecoder{
encoding: charset.NewEncoding(chs),
buffer: nil,
}
}

// clean prevents the inputDecoder from holding too much memory.
func (i *inputDecoder) clean() {
i.buffer = nil
}

func (i *inputDecoder) decodeInput(src []byte) []byte {
result, err := i.encoding.Decode(i.buffer, src)
if err != nil {
return src
}
return result
}

type resultEncoder struct {
// chsName and encoding are unchanged after the initialization from
// session variable @@character_set_results.
Expand Down

0 comments on commit 04a9618

Please sign in to comment.