From 898d46a7a47b06dae34e33714c791a60f8689882 Mon Sep 17 00:00:00 2001 From: djshow832 Date: Tue, 9 Aug 2022 17:12:12 +0800 Subject: [PATCH] backend: add tests for prepared statements (#33) --- pkg/proxy/backend/authenticator_test.go | 4 +- pkg/proxy/backend/cmd_processor_exec.go | 20 +- pkg/proxy/backend/cmd_processor_query.go | 2 + pkg/proxy/backend/cmd_processor_test.go | 559 +++++++++++++++++++++-- pkg/proxy/backend/mock_backend_test.go | 87 +++- pkg/proxy/backend/mock_client_test.go | 141 +++++- pkg/proxy/backend/mock_proxy_test.go | 11 + pkg/proxy/backend/testsuite_test.go | 82 +++- pkg/proxy/net/packetio.go | 5 + pkg/proxy/net/packetio_mysql.go | 8 +- 10 files changed, 811 insertions(+), 108 deletions(-) diff --git a/pkg/proxy/backend/authenticator_test.go b/pkg/proxy/backend/authenticator_test.go index 66ff1a53..5b7bb1dd 100644 --- a/pkg/proxy/backend/authenticator_test.go +++ b/pkg/proxy/backend/authenticator_test.go @@ -46,9 +46,9 @@ func TestTLSConnection(t *testing.T) { cfgOverriders := getCfgCombinations(cfgs) for _, cfgs := range cfgOverriders { ts, clean := newTestSuite(t, tc, cfgs...) - ts.authenticateFirstTime(t, func(t *testing.T, _ *testSuite, _, _, perr error) { + ts.authenticateFirstTime(t, func(t *testing.T, _ *testSuite) { if ts.mb.backendConfig.capability&mysql.ClientSSL == 0 { - require.ErrorContains(t, perr, "must enable TLS") + require.ErrorContains(t, ts.mp.err, "must enable TLS") } }) clean() diff --git a/pkg/proxy/backend/cmd_processor_exec.go b/pkg/proxy/backend/cmd_processor_exec.go index 149da7a8..b3c8d721 100644 --- a/pkg/proxy/backend/cmd_processor_exec.go +++ b/pkg/proxy/backend/cmd_processor_exec.go @@ -62,6 +62,8 @@ func (cp *CmdProcessor) forwardCommand(clientIO, backendIO *pnet.PacketIO, reque return cp.forwardQueryCmd(clientIO, backendIO, request) case mysql.ComStmtClose: return cp.forwardCloseCmd(request) + case mysql.ComStmtSendLongData: + return cp.forwardSendLongDataCmd(request) case mysql.ComChangeUser: return cp.forwardChangeUserCmd(clientIO, backendIO, request) case mysql.ComStatistics: @@ -131,7 +133,7 @@ func (cp *CmdProcessor) forwardPrepareCmd(clientIO, backendIO *pnet.PacketIO) (s succeed = true } for i := 0; i < expectedEOFNum; i++ { - // The server status in EOF packets is always 0, so ignore it. + // Ignore this status because PREPARE doesn't affect status. if _, err = forwardUntilEOF(clientIO, backendIO); err != nil { return } @@ -164,12 +166,8 @@ func (cp *CmdProcessor) forwardQueryCmd(clientIO, backendIO *pnet.PacketIO, requ var serverStatus uint16 switch response[0] { case mysql.OKHeader: - if err = clientIO.Flush(); err != nil { - return false, err - } rs := cp.handleOKPacket(request, response) - serverStatus = rs.Status - succeed = true + serverStatus, succeed, err = rs.Status, true, clientIO.Flush() case mysql.ErrHeader: // Subsequent statements won't be executed even if it's a multi-statement. return false, clientIO.Flush() @@ -205,8 +203,7 @@ func (cp *CmdProcessor) forwardLoadInFile(clientIO, backendIO *pnet.PacketIO, re } } var response []byte - response, err = forwardOnePacket(clientIO, backendIO, true) - if err != nil { + if response, err = forwardOnePacket(clientIO, backendIO, true); err != nil { return } if response[0] == mysql.OKHeader { @@ -233,6 +230,7 @@ func (cp *CmdProcessor) forwardResultSet(clientIO, backendIO *pnet.PacketIO, req if response, err = forwardOnePacket(clientIO, backendIO, false); err != nil { return } + // An error may occur when the backend writes rows. if response[0] == mysql.ErrHeader { return 0, false, clientIO.Flush() } @@ -251,6 +249,12 @@ func (cp *CmdProcessor) forwardCloseCmd(request []byte) (succeed bool, err error return true, nil } +func (cp *CmdProcessor) forwardSendLongDataCmd(request []byte) (succeed bool, err error) { + // No packet is sent to the client for COM_STMT_SEND_LONG_DATA. + cp.updatePrepStmtStatus(request, 0) + return true, nil +} + func (cp *CmdProcessor) forwardChangeUserCmd(clientIO, backendIO *pnet.PacketIO, request []byte) (succeed bool, err error) { // Currently, TiDB responses with an OK or Err packet. But according to the MySQL doc, the server may send a // switch auth request. diff --git a/pkg/proxy/backend/cmd_processor_query.go b/pkg/proxy/backend/cmd_processor_query.go index f1a2658f..2860cee0 100644 --- a/pkg/proxy/backend/cmd_processor_query.go +++ b/pkg/proxy/backend/cmd_processor_query.go @@ -57,6 +57,7 @@ func (cp *CmdProcessor) query(packetIO *pnet.PacketIO, sql string) (result *gomy return } +// readResultSet is only used for reading the results of `show session_states` currently. func (cp *CmdProcessor) readResultSet(packetIO *pnet.PacketIO, data []byte) (*gomysql.Result, error) { columnCount, _, n := pnet.ParseLengthEncodedInt(data) if n-len(data) != 0 { @@ -114,6 +115,7 @@ func (cp *CmdProcessor) readResultRows(packetIO *pnet.PacketIO, result *gomysql. result.Status = binary.LittleEndian.Uint16(data[3:]) break } + // An error may occur when the backend writes rows. if data[0] == mysql.ErrHeader { return cp.handleErrorPacket(data) } diff --git a/pkg/proxy/backend/cmd_processor_test.go b/pkg/proxy/backend/cmd_processor_test.go index 3e04eaf3..cf98367f 100644 --- a/pkg/proxy/backend/cmd_processor_test.go +++ b/pkg/proxy/backend/cmd_processor_test.go @@ -18,6 +18,7 @@ import ( "testing" "github.com/pingcap/tidb/parser/mysql" + "github.com/stretchr/testify/require" ) type respondType int @@ -32,51 +33,53 @@ const ( responseTypeString responseTypeEOF responseTypeSwitchRequest + responseTypePrepareOK + responseTypeRow ) -// cmdResponseTypes lists simple commands and their responses. +// cmdResponseTypes lists all commands and their responses. var cmdResponseTypes = map[byte][]respondType{ - mysql.ComSleep: {responseTypeErr}, - mysql.ComQuit: {responseTypeOK, responseTypeNone}, - mysql.ComInitDB: {responseTypeOK, responseTypeErr}, - mysql.ComQuery: {responseTypeOK, responseTypeErr, responseTypeResultSet, responseTypeLoadFile}, - mysql.ComFieldList: {responseTypeErr, responseTypeColumn}, - mysql.ComCreateDB: {responseTypeOK, responseTypeErr}, - mysql.ComDropDB: {responseTypeOK, responseTypeErr}, - mysql.ComRefresh: {responseTypeOK, responseTypeErr}, - mysql.ComShutdown: {responseTypeOK, responseTypeErr}, - mysql.ComStatistics: {responseTypeString}, - mysql.ComProcessInfo: {responseTypeErr, responseTypeResultSet}, - mysql.ComConnect: {responseTypeErr}, - mysql.ComProcessKill: {responseTypeOK, responseTypeErr}, - mysql.ComDebug: {responseTypeEOF, responseTypeErr}, - mysql.ComPing: {responseTypeOK}, - mysql.ComTime: {responseTypeErr}, - mysql.ComDelayedInsert: {responseTypeErr}, - mysql.ComChangeUser: {responseTypeSwitchRequest, responseTypeOK, responseTypeErr}, - mysql.ComBinlogDump: {responseTypeErr}, - mysql.ComTableDump: {responseTypeErr}, - mysql.ComConnectOut: {responseTypeErr}, - mysql.ComRegisterSlave: {responseTypeErr}, - //mysql.ComStmtPrepare: {responseTypeOK, responseTypeErr}, - //mysql.ComStmtExecute: {responseTypeOK, responseTypeErr}, - //mysql.ComStmtSendLongData: {responseTypeOK, responseTypeErr}, - //mysql.ComStmtClose: {responseTypeOK, responseTypeErr}, - //mysql.ComStmtReset: {responseTypeOK, responseTypeErr}, - mysql.ComSetOption: {responseTypeEOF, responseTypeErr}, - //mysql.ComStmtFetch: {responseTypeOK, responseTypeErr}, - mysql.ComDaemon: {responseTypeErr}, - mysql.ComBinlogDumpGtid: {responseTypeErr}, - mysql.ComResetConnection: {responseTypeOK, responseTypeErr}, - mysql.ComEnd: {responseTypeErr}, + mysql.ComSleep: {responseTypeErr}, + mysql.ComQuit: {responseTypeNone}, + mysql.ComInitDB: {responseTypeOK, responseTypeErr}, + mysql.ComQuery: {responseTypeOK, responseTypeErr, responseTypeResultSet, responseTypeLoadFile}, + mysql.ComFieldList: {responseTypeErr, responseTypeColumn}, + mysql.ComCreateDB: {responseTypeOK, responseTypeErr}, + mysql.ComDropDB: {responseTypeOK, responseTypeErr}, + mysql.ComRefresh: {responseTypeOK, responseTypeErr}, + mysql.ComShutdown: {responseTypeOK, responseTypeErr}, + mysql.ComStatistics: {responseTypeString}, + mysql.ComProcessInfo: {responseTypeErr, responseTypeResultSet}, + mysql.ComConnect: {responseTypeErr}, + mysql.ComProcessKill: {responseTypeOK, responseTypeErr}, + mysql.ComDebug: {responseTypeEOF, responseTypeErr}, + mysql.ComPing: {responseTypeOK}, + mysql.ComTime: {responseTypeErr}, + mysql.ComDelayedInsert: {responseTypeErr}, + mysql.ComChangeUser: {responseTypeSwitchRequest, responseTypeOK, responseTypeErr}, + mysql.ComBinlogDump: {responseTypeErr}, + mysql.ComTableDump: {responseTypeErr}, + mysql.ComConnectOut: {responseTypeErr}, + mysql.ComRegisterSlave: {responseTypeErr}, + mysql.ComStmtPrepare: {responseTypePrepareOK, responseTypeErr}, + mysql.ComStmtExecute: {responseTypeOK, responseTypeErr, responseTypeResultSet}, + mysql.ComStmtSendLongData: {responseTypeNone}, + mysql.ComStmtClose: {responseTypeNone}, + mysql.ComStmtReset: {responseTypeOK, responseTypeErr}, + mysql.ComSetOption: {responseTypeEOF, responseTypeErr}, + mysql.ComStmtFetch: {responseTypeRow, responseTypeErr}, + mysql.ComDaemon: {responseTypeErr}, + mysql.ComBinlogDumpGtid: {responseTypeErr}, + mysql.ComResetConnection: {responseTypeOK, responseTypeErr}, + mysql.ComEnd: {responseTypeErr}, } -func TestSimpleCommands(t *testing.T) { +// Test forwarding packets between the client and the backend. +func TestForwardCommands(t *testing.T) { tc := newTCPConnSuite(t) runTest := func(cfgs ...cfgOverrider) { ts, clean := newTestSuite(t, tc, cfgs...) - // Only verify that it won't hang or report errors. - ts.executeCmd(t) + ts.executeCmd(t, nil) clean() } // Test every respond type for every command. @@ -89,14 +92,31 @@ func TestSimpleCommands(t *testing.T) { // Test more variables for some special response types. switch respondType { case responseTypeColumn: - for _, columns := range []int{1, 3} { + for _, columns := range []int{1, 4096} { extraCfgOvr := func(cfg *testConfig) { cfg.backendConfig.columns = columns } runTest(cfgOvr, extraCfgOvr) } + case responseTypeRow: + for _, rows := range []int{0, 1, 3} { + extraCfgOvr := func(cfg *testConfig) { + cfg.backendConfig.rows = rows + } + runTest(cfgOvr, extraCfgOvr) + } + case responseTypePrepareOK: + for _, columns := range []int{0, 1, 4096} { + for _, params := range []int{0, 1, 3} { + extraCfgOvr := func(cfg *testConfig) { + cfg.backendConfig.columns = columns + cfg.backendConfig.params = params + } + runTest(cfgOvr, extraCfgOvr) + } + } case responseTypeResultSet: - for _, columns := range []int{1, 3} { + for _, columns := range []int{1, 4096} { for _, rows := range []int{0, 1, 3} { extraCfgOvr := func(cfg *testConfig) { cfg.backendConfig.columns = columns @@ -118,3 +138,464 @@ func TestSimpleCommands(t *testing.T) { } } } + +// Test querying directly from the server. +func TestDirectQuery(t *testing.T) { + tc := newTCPConnSuite(t) + tests := []struct { + cfg cfgOverrider + c checker + }{ + { + cfg: func(cfg *testConfig) { + cfg.backendConfig.columns = 2 + cfg.backendConfig.rows = 1 + cfg.backendConfig.respondType = responseTypeResultSet + }, + }, + { + cfg: func(cfg *testConfig) { + cfg.backendConfig.respondType = responseTypeErr + }, + c: func(t *testing.T, ts *testSuite) { + require.Error(t, ts.mp.err) + require.NoError(t, ts.mb.err) + }, + }, + { + cfg: func(cfg *testConfig) { + cfg.backendConfig.respondType = responseTypeOK + }, + }, + } + for _, test := range tests { + ts, clean := newTestSuite(t, tc, test.cfg) + ts.query(t, test.c) + clean() + } +} + +func TestPreparedStmts(t *testing.T) { + tc := newTCPConnSuite(t) + tests := []struct { + cfgs []cfgOverrider + canRedirect bool + }{ + // prepare + { + cfgs: []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtPrepare + cfg.backendConfig.respondType = responseTypePrepareOK + }, + }, + canRedirect: true, + }, + // send long data + { + cfgs: []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtSendLongData + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.respondType = responseTypeNone + }, + }, + canRedirect: false, + }, + // send long data and execute + { + cfgs: []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtSendLongData + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.respondType = responseTypeNone + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtExecute + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.columns = 1 + cfg.backendConfig.respondType = responseTypeResultSet + }, + }, + canRedirect: true, + }, + { + cfgs: []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtSendLongData + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.respondType = responseTypeNone + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtExecute + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.respondType = responseTypeOK + }, + }, + canRedirect: true, + }, + { + cfgs: []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtSendLongData + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.respondType = responseTypeNone + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtExecute + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.respondType = responseTypeErr + }, + }, + canRedirect: false, + }, + { + cfgs: []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtSendLongData + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.respondType = responseTypeNone + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtExecute + cfg.clientConfig.prepStmtID = 2 + cfg.backendConfig.respondType = responseTypeOK + }, + }, + canRedirect: false, + }, + { + cfgs: []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtSendLongData + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.respondType = responseTypeNone + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtSendLongData + cfg.clientConfig.prepStmtID = 2 + cfg.backendConfig.respondType = responseTypeNone + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtExecute + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.respondType = responseTypeOK + }, + }, + canRedirect: false, + }, + { + cfgs: []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtSendLongData + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.respondType = responseTypeNone + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtExecute + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.columns = 1 + cfg.backendConfig.respondType = responseTypeResultSet + cfg.backendConfig.status = mysql.ServerStatusCursorExists + }, + }, + canRedirect: false, + }, + // execute and fetch + { + cfgs: []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtExecute + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.columns = 1 + cfg.backendConfig.respondType = responseTypeResultSet + cfg.backendConfig.status = mysql.ServerStatusCursorExists + }, + }, + canRedirect: false, + }, + { + cfgs: []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtExecute + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.columns = 1 + cfg.backendConfig.respondType = responseTypeResultSet + cfg.backendConfig.status = mysql.ServerStatusCursorExists + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtFetch + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.respondType = responseTypeRow + cfg.backendConfig.status = mysql.ServerStatusCursorExists + }, + }, + canRedirect: false, + }, + { + cfgs: []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtExecute + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.columns = 1 + cfg.backendConfig.respondType = responseTypeResultSet + cfg.backendConfig.status = mysql.ServerStatusCursorExists + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtFetch + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.respondType = responseTypeRow + cfg.backendConfig.status = mysql.ServerStatusLastRowSend + }, + }, + canRedirect: true, + }, + { + cfgs: []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtExecute + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.columns = 1 + cfg.backendConfig.respondType = responseTypeResultSet + cfg.backendConfig.status = mysql.ServerStatusCursorExists + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtFetch + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.respondType = responseTypeErr + }, + }, + canRedirect: false, + }, + { + cfgs: []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtExecute + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.columns = 1 + cfg.backendConfig.respondType = responseTypeResultSet + cfg.backendConfig.status = mysql.ServerStatusCursorExists + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtExecute + cfg.clientConfig.prepStmtID = 2 + cfg.backendConfig.columns = 1 + cfg.backendConfig.respondType = responseTypeResultSet + cfg.backendConfig.status = mysql.ServerStatusCursorExists + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtFetch + cfg.clientConfig.prepStmtID = 2 + cfg.backendConfig.respondType = responseTypeRow + cfg.backendConfig.status = mysql.ServerStatusLastRowSend + }, + }, + canRedirect: false, + }, + { + cfgs: []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtSendLongData + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.respondType = responseTypeNone + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtExecute + cfg.clientConfig.prepStmtID = 2 + cfg.backendConfig.columns = 1 + cfg.backendConfig.respondType = responseTypeResultSet + cfg.backendConfig.status = mysql.ServerStatusCursorExists + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtFetch + cfg.clientConfig.prepStmtID = 2 + cfg.backendConfig.respondType = responseTypeRow + cfg.backendConfig.status = mysql.ServerStatusLastRowSend + }, + }, + canRedirect: false, + }, + // send long data and close/reset + { + cfgs: []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtSendLongData + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.respondType = responseTypeNone + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtClose + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.respondType = responseTypeNone + }, + }, + canRedirect: true, + }, + { + cfgs: []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtSendLongData + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.respondType = responseTypeNone + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtClose + cfg.clientConfig.prepStmtID = 2 + cfg.backendConfig.respondType = responseTypeNone + }, + }, + canRedirect: false, + }, + { + cfgs: []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtSendLongData + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.respondType = responseTypeNone + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtReset + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.respondType = responseTypeOK + }, + }, + canRedirect: true, + }, + { + cfgs: []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtSendLongData + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.respondType = responseTypeNone + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtReset + cfg.clientConfig.prepStmtID = 2 + cfg.backendConfig.respondType = responseTypeOK + }, + }, + canRedirect: false, + }, + // execute and close/reset + { + cfgs: []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtExecute + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.columns = 1 + cfg.backendConfig.respondType = responseTypeResultSet + cfg.backendConfig.status = mysql.ServerStatusCursorExists + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtClose + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.respondType = responseTypeNone + }, + }, + canRedirect: true, + }, + { + cfgs: []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtExecute + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.columns = 1 + cfg.backendConfig.respondType = responseTypeResultSet + cfg.backendConfig.status = mysql.ServerStatusCursorExists + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtReset + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.respondType = responseTypeOK + }, + }, + canRedirect: true, + }, + { + cfgs: []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtExecute + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.columns = 1 + cfg.backendConfig.respondType = responseTypeResultSet + cfg.backendConfig.status = mysql.ServerStatusCursorExists + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtClose + cfg.clientConfig.prepStmtID = 2 + cfg.backendConfig.respondType = responseTypeNone + }, + }, + canRedirect: false, + }, + { + cfgs: []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtExecute + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.columns = 1 + cfg.backendConfig.respondType = responseTypeResultSet + cfg.backendConfig.status = mysql.ServerStatusCursorExists + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtReset + cfg.clientConfig.prepStmtID = 2 + cfg.backendConfig.respondType = responseTypeOK + }, + }, + canRedirect: false, + }, + // reset connection and change user + { + cfgs: []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtSendLongData + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.respondType = responseTypeNone + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtExecute + cfg.clientConfig.prepStmtID = 2 + cfg.backendConfig.columns = 1 + cfg.backendConfig.respondType = responseTypeResultSet + cfg.backendConfig.status = mysql.ServerStatusCursorExists + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComResetConnection + cfg.backendConfig.respondType = responseTypeOK + }, + }, + canRedirect: true, + }, + { + cfgs: []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtSendLongData + cfg.clientConfig.prepStmtID = 1 + cfg.backendConfig.respondType = responseTypeNone + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComStmtExecute + cfg.clientConfig.prepStmtID = 2 + cfg.backendConfig.columns = 1 + cfg.backendConfig.respondType = responseTypeResultSet + cfg.backendConfig.status = mysql.ServerStatusCursorExists + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComChangeUser + cfg.backendConfig.respondType = responseTypeOK + }, + }, + canRedirect: true, + }, + } + + for _, test := range tests { + ts, clean := newTestSuite(t, tc) + c := func(t *testing.T, ts *testSuite) { + require.Equal(t, test.canRedirect, ts.mp.cmdProcessor.canRedirect()) + } + ts.executeMultiCmd(t, test.cfgs, c) + clean() + } +} diff --git a/pkg/proxy/backend/mock_backend_test.go b/pkg/proxy/backend/mock_backend_test.go index b151d5fe..2b58cb28 100644 --- a/pkg/proxy/backend/mock_backend_test.go +++ b/pkg/proxy/backend/mock_backend_test.go @@ -35,6 +35,8 @@ type backendConfig struct { respondType respondType columns int rows int + params int + status uint16 } type mockBackend struct { @@ -45,6 +47,7 @@ type mockBackend struct { authData []byte db string attrs []byte + err error } func newMockBackend(cfg *backendConfig) *mockBackend { @@ -105,7 +108,7 @@ func (mb *mockBackend) verifyPassword(packetIO *pnet.PacketIO) error { } } if mb.authSucceed { - if err := packetIO.WriteOKPacket(); err != nil { + if err := packetIO.WriteOKPacket(mb.status); err != nil { return err } } else { @@ -117,12 +120,13 @@ func (mb *mockBackend) verifyPassword(packetIO *pnet.PacketIO) error { } func (mb *mockBackend) respond(packetIO *pnet.PacketIO) error { + packetIO.ResetSequence() if _, err := packetIO.ReadPacket(); err != nil { return err } switch mb.respondType { case responseTypeOK: - return packetIO.WriteOKPacket() + return packetIO.WriteOKPacket(mb.status) case responseTypeErr: return packetIO.WriteErrPacket(mysql.NewErr(mysql.ErrUnknown)) case responseTypeResultSet: @@ -134,7 +138,7 @@ func (mb *mockBackend) respond(packetIO *pnet.PacketIO) error { case responseTypeString: return packetIO.WritePacket([]byte(mockCmdStr), true) case responseTypeEOF: - return packetIO.WriteEOFPacket() + return packetIO.WriteEOFPacket(mb.status) case responseTypeSwitchRequest: if err := packetIO.WriteSwitchRequest(mb.authPlugin, mb.salt); err != nil { return err @@ -142,7 +146,11 @@ func (mb *mockBackend) respond(packetIO *pnet.PacketIO) error { if _, err := packetIO.ReadPacket(); err != nil { return err } - return packetIO.WriteOKPacket() + return packetIO.WriteOKPacket(mb.status) + case responseTypePrepareOK: + return mb.respondPrepare(packetIO) + case responseTypeRow: + return mb.respondRows(packetIO) case responseTypeNone: return nil } @@ -152,11 +160,21 @@ func (mb *mockBackend) respond(packetIO *pnet.PacketIO) error { // respond to FieldList func (mb *mockBackend) respondColumns(packetIO *pnet.PacketIO) error { for i := 0; i < mb.columns; i++ { - if err := packetIO.WritePacket(mockCmdBytes, true); err != nil { + if err := packetIO.WritePacket(mockCmdBytes, false); err != nil { return err } } - return packetIO.WriteEOFPacket() + return packetIO.WriteEOFPacket(mb.status) +} + +// respond to Fetch +func (mb *mockBackend) respondRows(packetIO *pnet.PacketIO) error { + for i := 0; i < mb.rows; i++ { + if err := packetIO.WritePacket(mockCmdBytes, false); err != nil { + return err + } + } + return packetIO.WriteEOFPacket(mb.status) } // respond to Query @@ -186,19 +204,22 @@ func (mb *mockBackend) respondResultSet(packetIO *pnet.PacketIO) error { return err } } - if err := packetIO.WriteEOFPacket(); err != nil { + if err := packetIO.WriteEOFPacket(mb.status); err != nil { return err } - for _, row := range values { - var data []byte - for _, value := range row { - data = pnet.DumpLengthEncodedString(data, []byte(value.(string))) - } - if err := packetIO.WritePacket(data, false); err != nil { - return err + if mb.status&mysql.ServerStatusCursorExists == 0 { + for _, row := range values { + var data []byte + for _, value := range row { + data = pnet.DumpLengthEncodedString(data, []byte(value.(string))) + } + if err := packetIO.WritePacket(data, false); err != nil { + return err + } } + return packetIO.WriteEOFPacket(mb.status) } - return packetIO.WriteEOFPacket() + return nil } // respond to LoadInFile @@ -220,5 +241,39 @@ func (mb *mockBackend) respondLoadFile(packetIO *pnet.PacketIO) error { break } } - return packetIO.WriteOKPacket() + return packetIO.WriteOKPacket(mb.status) +} + +// respond to Prepare +func (mb *mockBackend) respondPrepare(packetIO *pnet.PacketIO) error { + data := []byte{mysql.OKHeader} + data = pnet.DumpUint32(data, uint32(mockCmdInt)) + data = pnet.DumpUint16(data, uint16(mb.columns)) + data = pnet.DumpUint16(data, uint16(mb.params)) + data = append(data, 0x00) + data = pnet.DumpUint16(data, uint16(mockCmdInt)) + if err := packetIO.WritePacket(data, true); err != nil { + return err + } + if mb.params > 0 { + for i := 0; i < mb.params; i++ { + if err := packetIO.WritePacket(mockCmdBytes, false); err != nil { + return err + } + } + if err := packetIO.WriteEOFPacket(mb.status); err != nil { + return err + } + } + if mb.columns > 0 { + for i := 0; i < mb.columns; i++ { + if err := packetIO.WritePacket(mockCmdBytes, false); err != nil { + return err + } + } + if err := packetIO.WriteEOFPacket(mb.status); err != nil { + return err + } + } + return nil } diff --git a/pkg/proxy/backend/mock_client_test.go b/pkg/proxy/backend/mock_client_test.go index ffc7f5c0..142185d1 100644 --- a/pkg/proxy/backend/mock_client_test.go +++ b/pkg/proxy/backend/mock_client_test.go @@ -16,6 +16,7 @@ package backend import ( "crypto/tls" + "encoding/binary" pnet "github.com/pingcap/TiProxy/pkg/proxy/net" "github.com/pingcap/tidb/parser/mysql" @@ -32,8 +33,9 @@ type clientConfig struct { authData []byte attrs []byte // for cmd - cmd byte - filePkts int + cmd byte + filePkts int + prepStmtID int } type mockClient struct { @@ -41,6 +43,7 @@ type mockClient struct { *clientConfig // Outputs that received from the server and will be checked by the test. authSucceed bool + err error } func newMockClient(cfg *clientConfig) *mockClient { @@ -98,28 +101,40 @@ func (mc *mockClient) writePassword(packetIO *pnet.PacketIO) error { // request sends commands except prepared statements commands. func (mc *mockClient) request(packetIO *pnet.PacketIO) error { + packetIO.ResetSequence() data := []byte{mc.cmd} switch mc.cmd { case mysql.ComInitDB, mysql.ComCreateDB, mysql.ComDropDB: data = append(data, []byte(mockCmdStr)...) case mysql.ComQuery: return mc.query(packetIO) + case mysql.ComProcessInfo: + return mc.requestProcessInfo(packetIO) case mysql.ComFieldList: - data = append(data, []byte(mockCmdStr)...) - data = append(data, 0x00) - data = append(data, []byte(mockCmdStr)...) + return mc.requestFieldList(packetIO) case mysql.ComRefresh, mysql.ComSetOption: data = append(data, mockCmdByte) case mysql.ComProcessKill: - data = append(data, byte(mockCmdInt), byte(mockCmdInt>>8), byte(mockCmdInt>>16), byte(mockCmdInt>>24)) + data = pnet.DumpUint32(data, uint32(mockCmdInt)) case mysql.ComChangeUser: return mc.requestChangeUser(packetIO) + case mysql.ComStmtPrepare: + return mc.requestPrepare(packetIO) + case mysql.ComStmtSendLongData: + data = pnet.DumpUint32(data, uint32(mc.prepStmtID)) + data = append(data, mockCmdBytes...) + case mysql.ComStmtExecute: + return mc.requestExecute(packetIO) + case mysql.ComStmtFetch: + return mc.requestFetch(packetIO) + case mysql.ComStmtClose, mysql.ComStmtReset: + data = pnet.DumpUint32(data, uint32(mc.prepStmtID)) } if err := packetIO.WritePacket(data, true); err != nil { return err } switch mc.cmd { - case mysql.ComQuit: + case mysql.ComQuit, mysql.ComStmtClose, mysql.ComStmtSendLongData: return nil } _, err := packetIO.ReadPacket() @@ -147,6 +162,101 @@ func (mc *mockClient) requestChangeUser(packetIO *pnet.PacketIO) error { } } +func (mc *mockClient) requestPrepare(packetIO *pnet.PacketIO) error { + data := make([]byte, 0, len(mockCmdStr)+1) + data = append(data, mysql.ComStmtPrepare) + data = append(data, []byte(mockCmdStr)...) + if err := packetIO.WritePacket(data, true); err != nil { + return err + } + response, err := packetIO.ReadPacket() + if err != nil { + return err + } + expectedEOFNum := 0 + if response[0] == mysql.OKHeader { + numColumns := binary.LittleEndian.Uint16(response[5:]) + if numColumns > 0 { + expectedEOFNum++ + } + numParams := binary.LittleEndian.Uint16(response[7:]) + if numParams > 0 { + expectedEOFNum++ + } + } + for i := 0; i < expectedEOFNum; i++ { + for { + if response, err = packetIO.ReadPacket(); err != nil { + return err + } + if pnet.IsEOFPacket(response) { + break + } + } + } + return nil +} + +func (mc *mockClient) requestExecute(packetIO *pnet.PacketIO) error { + data := make([]byte, 0, len(mockCmdBytes)+5) + data = append(data, mysql.ComStmtExecute) + data = pnet.DumpUint32(data, uint32(mc.prepStmtID)) + data = append(data, mockCmdBytes...) + if err := packetIO.WritePacket(data, true); err != nil { + return err + } + return mc.readResultSet(packetIO) +} + +func (mc *mockClient) requestFetch(packetIO *pnet.PacketIO) error { + data := make([]byte, 0, len(mockCmdBytes)+5) + data = append(data, mysql.ComStmtFetch) + data = pnet.DumpUint32(data, uint32(mc.prepStmtID)) + data = append(data, mockCmdBytes...) + if err := packetIO.WritePacket(data, true); err != nil { + return err + } + return mc.readErrOrUntilEOF(packetIO) +} + +func (mc *mockClient) requestFieldList(packetIO *pnet.PacketIO) error { + data := make([]byte, 0, len(mockCmdStr)+2) + data = append(data, mysql.ComFieldList) + data = append(data, []byte(mockCmdStr)...) + data = append(data, 0x00) + data = append(data, []byte(mockCmdStr)...) + if err := packetIO.WritePacket(data, true); err != nil { + return err + } + return mc.readErrOrUntilEOF(packetIO) +} + +func (mc *mockClient) readErrOrUntilEOF(packetIO *pnet.PacketIO) error { + pkt, err := packetIO.ReadPacket() + if err != nil { + return err + } + if pkt[0] == mysql.ErrHeader || pnet.IsEOFPacket(pkt) { + return nil + } + for { + if pkt, err = packetIO.ReadPacket(); err != nil { + return err + } + if pnet.IsEOFPacket(pkt) { + break + } + } + return nil +} + +func (mc *mockClient) requestProcessInfo(packetIO *pnet.PacketIO) error { + if err := packetIO.WritePacket([]byte{mysql.ComProcessInfo}, true); err != nil { + return err + } + return mc.readResultSet(packetIO) +} + func (mc *mockClient) query(packetIO *pnet.PacketIO) error { data := make([]byte, 0, len(mockCmdStr)+1) data = append(data, mysql.ComQuery) @@ -154,6 +264,10 @@ func (mc *mockClient) query(packetIO *pnet.PacketIO) error { if err := packetIO.WritePacket(data, true); err != nil { return err } + return mc.readResultSet(packetIO) +} + +func (mc *mockClient) readResultSet(packetIO *pnet.PacketIO) error { pkt, err := packetIO.ReadPacket() if err != nil { return err @@ -185,16 +299,9 @@ func (mc *mockClient) query(packetIO *pnet.PacketIO) error { break } } - for { - if pkt, err = packetIO.ReadPacket(); err != nil { - return err - } - if pkt[0] == mysql.ErrHeader { - return nil - } - if pnet.IsEOFPacket(pkt) { - break - } + serverStatus := binary.LittleEndian.Uint16(pkt[3:]) + if serverStatus&mysql.ServerStatusCursorExists == 0 { + return mc.readErrOrUntilEOF(packetIO) } } return nil diff --git a/pkg/proxy/backend/mock_proxy_test.go b/pkg/proxy/backend/mock_proxy_test.go index 57e2e9c1..bd13a13c 100644 --- a/pkg/proxy/backend/mock_proxy_test.go +++ b/pkg/proxy/backend/mock_proxy_test.go @@ -17,6 +17,7 @@ package backend import ( "crypto/tls" + gomysql "github.com/go-mysql-org/go-mysql/mysql" pnet "github.com/pingcap/TiProxy/pkg/proxy/net" ) @@ -28,8 +29,11 @@ type proxyConfig struct { type mockProxy struct { *proxyConfig + err error auth *Authenticator cmdProcessor *CmdProcessor + // outputs that received from the server. + rs *gomysql.Result } func newMockProxy(cfg *proxyConfig) *mockProxy { @@ -50,6 +54,7 @@ func (mp *mockProxy) authenticateSecondTime(_, backendIO *pnet.PacketIO) error { } func (mp *mockProxy) processCmd(clientIO, backendIO *pnet.PacketIO) error { + clientIO.ResetSequence() request, err := clientIO.ReadPacket() if err != nil { return err @@ -57,3 +62,9 @@ func (mp *mockProxy) processCmd(clientIO, backendIO *pnet.PacketIO) error { _, _, err = mp.cmdProcessor.executeCmd(request, clientIO, backendIO, false) return err } + +func (mp *mockProxy) directQuery(_, backendIO *pnet.PacketIO) error { + rs, _, err := mp.cmdProcessor.query(backendIO, mockCmdStr) + mp.rs = rs + return err +} diff --git a/pkg/proxy/backend/testsuite_test.go b/pkg/proxy/backend/testsuite_test.go index d38ae55a..f0dd978a 100644 --- a/pkg/proxy/backend/testsuite_test.go +++ b/pkg/proxy/backend/testsuite_test.go @@ -15,9 +15,11 @@ package backend import ( + "fmt" "strings" "testing" + pnet "github.com/pingcap/TiProxy/pkg/proxy/net" "github.com/pingcap/tidb/parser/mysql" "github.com/stretchr/testify/require" ) @@ -120,7 +122,7 @@ type testSuite struct { mc *mockClient } -type errChecker func(t *testing.T, ts *testSuite, cerr, berr, perr error) +type checker func(t *testing.T, ts *testSuite) func newTestSuite(t *testing.T, tc *tcpConnSuite, overriders ...cfgOverrider) (*testSuite, func()) { ts := &testSuite{} @@ -138,6 +140,13 @@ func newTestSuite(t *testing.T, tc *tcpConnSuite, overriders ...cfgOverrider) (* return ts, clean } +func (ts *testSuite) setConfig(overriders ...cfgOverrider) { + cfg := newTestConfig(overriders...) + ts.mb.backendConfig = &cfg.backendConfig + ts.mp.proxyConfig = &cfg.proxyConfig + ts.mc.clientConfig = &cfg.clientConfig +} + func (ts *testSuite) changeDB(db string) { ts.mc.dbName = db ts.mp.auth.updateCurrentDB(db) @@ -149,47 +158,76 @@ func (ts *testSuite) changeUser(username, db string) { ts.mp.auth.changeUser(username, db) } +func (ts *testSuite) runAndCheck(t *testing.T, c checker, clientRunner, backendRunner func(*pnet.PacketIO) error, + proxyRunner func(*pnet.PacketIO, *pnet.PacketIO) error) { + ts.mc.err, ts.mb.err, ts.mp.err = ts.tc.run(t, clientRunner, backendRunner, proxyRunner) + if c == nil { + require.NoError(t, ts.mc.err) + require.NoError(t, ts.mb.err) + require.NoError(t, ts.mp.err) + if clientRunner != nil && backendRunner != nil { + // Ensure all the packets are forwarded. + msg := fmt.Sprintf("cmd:%d responseType:%d", ts.mc.cmd, ts.mb.respondType) + require.Equal(t, ts.tc.backendIO.GetSequence(), ts.tc.clientIO.GetSequence(), msg) + } + } else { + c(t, ts) + } +} + // The client connects to the backend through the proxy. -func (ts *testSuite) authenticateFirstTime(t *testing.T, ce errChecker) { - cerr, berr, perr := ts.tc.run(t, ts.mc.authenticate, ts.mb.authenticate, ts.mp.authenticateFirstTime) - if ce == nil { - require.NoError(t, berr) - require.NoError(t, cerr) - require.NoError(t, perr) +func (ts *testSuite) authenticateFirstTime(t *testing.T, c checker) { + ts.runAndCheck(t, c, ts.mc.authenticate, ts.mb.authenticate, ts.mp.authenticateFirstTime) + if c == nil { // Check the data received by client equals to the data sent from the server and vice versa. require.Equal(t, ts.mb.authSucceed, ts.mc.authSucceed) require.Equal(t, ts.mc.username, ts.mb.username) require.Equal(t, ts.mc.dbName, ts.mb.db) require.Equal(t, ts.mc.authData, ts.mb.authData) require.Equal(t, ts.mc.attrs, ts.mb.attrs) - } else { - ce(t, ts, cerr, berr, perr) } } // The proxy reconnects to the proxy using preserved client data. // This must be called after authenticateFirstTime. -func (ts *testSuite) authenticateSecondTime(t *testing.T, ce errChecker) { +func (ts *testSuite) authenticateSecondTime(t *testing.T, c checker) { // The server won't request switching auth-plugin this time. ts.mb.backendConfig.switchAuth = false ts.mb.backendConfig.authSucceed = true - cerr, berr, perr := ts.tc.run(t, nil, ts.mb.authenticate, ts.mp.authenticateSecondTime) - if ce == nil { - require.NoError(t, berr) - require.NoError(t, cerr) - require.NoError(t, perr) + ts.runAndCheck(t, c, nil, ts.mb.authenticate, ts.mp.authenticateSecondTime) + if c == nil { require.Equal(t, ts.mc.username, ts.mb.username) require.Equal(t, ts.mc.dbName, ts.mb.db) require.Equal(t, []byte(ts.mp.sessionToken), ts.mb.authData) - } else { - ce(t, ts, cerr, berr, perr) } } // Test forwarding commands between the client and the server. -func (ts *testSuite) executeCmd(t *testing.T) { - cerr, berr, perr := ts.tc.run(t, ts.mc.request, ts.mb.respond, ts.mp.processCmd) - require.NoError(t, berr) - require.NoError(t, cerr) - require.NoError(t, perr) +// It verifies that it won't hang or report errors, and all the packets are forwarded. +func (ts *testSuite) executeCmd(t *testing.T, c checker) { + ts.runAndCheck(t, c, ts.mc.request, ts.mb.respond, ts.mp.processCmd) +} + +// Execute multiple commands at once to reuse the same ComProcessor. +func (ts *testSuite) executeMultiCmd(t *testing.T, cfgs []cfgOverrider, c checker) { + for _, cfg := range cfgs { + ts.setConfig(cfg) + ts.runAndCheck(t, nil, ts.mc.request, ts.mb.respond, ts.mp.processCmd) + } + // Only check it at last. + if c != nil { + c(t, ts) + } +} + +// Test querying from the backend directly. +// It verifies that it won't hang or panic, and column / row counts match. +func (ts *testSuite) query(t *testing.T, c checker) { + ts.runAndCheck(t, c, nil, ts.mb.respond, ts.mp.directQuery) + if c == nil { + if ts.mb.respondType == responseTypeResultSet { + require.Equal(t, ts.mb.columns, len(ts.mp.rs.Fields)) + require.Equal(t, ts.mb.rows, len(ts.mp.rs.RowDatas)) + } + } } diff --git a/pkg/proxy/net/packetio.go b/pkg/proxy/net/packetio.go index 19baf7d2..98590355 100644 --- a/pkg/proxy/net/packetio.go +++ b/pkg/proxy/net/packetio.go @@ -119,6 +119,11 @@ func (p *PacketIO) ResetSequence() { p.sequence = 0 } +// GetSequence is used in tests to assert that the sequences on the client and server are equal. +func (p *PacketIO) GetSequence() uint8 { + return p.sequence +} + func (p *PacketIO) ReadOnePacket() ([]byte, bool, error) { var header [4]byte diff --git a/pkg/proxy/net/packetio_mysql.go b/pkg/proxy/net/packetio_mysql.go index 79bec001..22634fbb 100644 --- a/pkg/proxy/net/packetio_mysql.go +++ b/pkg/proxy/net/packetio_mysql.go @@ -106,22 +106,22 @@ func (p *PacketIO) WriteErrPacket(merr *mysql.SQLError) error { } // WriteOKPacket writes an OK packet. It's only for testing. -func (p *PacketIO) WriteOKPacket() error { +func (p *PacketIO) WriteOKPacket(status uint16) error { data := make([]byte, 0, 7) data = append(data, mysql.OKHeader) data = append(data, 0, 0) // It's compatible with both ClientProtocol41 enabled and disabled. - data = DumpUint16(data, mysql.ServerStatusAutocommit) + data = DumpUint16(data, status) data = append(data, 0, 0) return p.WritePacket(data, true) } // WriteEOFPacket writes an EOF packet. It's only for testing. -func (p *PacketIO) WriteEOFPacket() error { +func (p *PacketIO) WriteEOFPacket(status uint16) error { data := make([]byte, 0, 5) data = append(data, mysql.EOFHeader) data = append(data, 0, 0) // It's compatible with both ClientProtocol41 enabled and disabled. - data = DumpUint16(data, mysql.ServerStatusAutocommit) + data = DumpUint16(data, status) return p.WritePacket(data, true) }