diff --git a/pkg/proxy/backend/cmd_processor_test.go b/pkg/proxy/backend/cmd_processor_test.go index b0489775..c0358f64 100644 --- a/pkg/proxy/backend/cmd_processor_test.go +++ b/pkg/proxy/backend/cmd_processor_test.go @@ -963,3 +963,47 @@ func TestBeginStmt(t *testing.T) { require.Equal(t, test.isBegin, isBeginStmt(test.stmt)) } } + +// Test forwarding multi-statements works well. +func TestMultiStmt(t *testing.T) { + tc := newTCPConnSuite(t) + + // COM_STMT_PREPARE don't support multiple statements, so we only test COM_QUERY. + cfgs := []cfgOverrider{ + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComQuery + cfg.backendConfig.respondType = responseTypeOK + cfg.backendConfig.stmtNum = 2 + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComQuery + cfg.backendConfig.respondType = responseTypeResultSet + cfg.backendConfig.columns = 1 + cfg.backendConfig.stmtNum = 2 + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComQuery + cfg.backendConfig.respondType = responseTypeResultSet + cfg.backendConfig.columns = 1 + cfg.backendConfig.rows = 1 + cfg.backendConfig.stmtNum = 2 + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComQuery + cfg.backendConfig.respondType = responseTypeResultSet + cfg.backendConfig.columns = 1 + cfg.backendConfig.stmtNum = 3 + }, + func(cfg *testConfig) { + cfg.clientConfig.cmd = mysql.ComQuery + cfg.backendConfig.respondType = responseTypeLoadFile + cfg.backendConfig.stmtNum = 2 + }, + } + + for _, cfg := range cfgs { + ts, clean := newTestSuite(t, tc, cfg) + ts.executeCmd(t, nil) + clean() + } +} diff --git a/pkg/proxy/backend/mock_backend_test.go b/pkg/proxy/backend/mock_backend_test.go index cf6acea5..2c03b99a 100644 --- a/pkg/proxy/backend/mock_backend_test.go +++ b/pkg/proxy/backend/mock_backend_test.go @@ -38,6 +38,7 @@ type backendConfig struct { params int status uint16 loops int + stmtNum int } func newBackendConfig() *backendConfig { @@ -48,6 +49,7 @@ func newBackendConfig() *backendConfig { switchAuth: true, authSucceed: true, loops: 1, + stmtNum: 1, } } @@ -147,7 +149,7 @@ func (mb *mockBackend) respondOnce(packetIO *pnet.PacketIO) error { } switch mb.respondType { case responseTypeOK: - return packetIO.WriteOKPacket(mb.status) + return mb.respondOK(packetIO) case responseTypeErr: return packetIO.WriteErrPacket(mysql.NewErr(mysql.ErrUnknown)) case responseTypeResultSet: @@ -178,6 +180,21 @@ func (mb *mockBackend) respondOnce(packetIO *pnet.PacketIO) error { return packetIO.WriteErrPacket(mysql.NewErr(mysql.ErrUnknown)) } +func (mb *mockBackend) respondOK(packetIO *pnet.PacketIO) error { + for i := 0; i < mb.stmtNum; i++ { + status := mb.status + if i < mb.stmtNum-1 { + status |= mysql.ServerMoreResultsExists + } else { + status &= ^mysql.ServerMoreResultsExists + } + if err := packetIO.WriteOKPacket(status); err != nil { + return err + } + } + return nil +} + // respond to FieldList func (mb *mockBackend) respondColumns(packetIO *pnet.PacketIO) error { for i := 0; i < mb.columns; i++ { @@ -216,53 +233,75 @@ func (mb *mockBackend) respondResultSet(packetIO *pnet.PacketIO) error { if err != nil { return err } - data := pnet.DumpLengthEncodedInt(nil, uint64(mb.columns)) - if err := packetIO.WritePacket(data, false); err != nil { - return err - } - for _, field := range rs.Fields { - if err := packetIO.WritePacket(field.Dump(), false); err != nil { + for i := 0; i < mb.stmtNum; i++ { + status := mb.status + if i < mb.stmtNum-1 { + status |= mysql.ServerMoreResultsExists + } else { + status &= ^mysql.ServerMoreResultsExists + } + data := pnet.DumpLengthEncodedInt(nil, uint64(mb.columns)) + if err := packetIO.WritePacket(data, false); err != nil { return err } - } - if err := packetIO.WriteEOFPacket(mb.status); err != nil { - return err - } - if mb.status&mysql.ServerStatusCursorExists == 0 { - for _, row := range values { - var data []byte - for _, value := range row { - data = pnet.DumpLengthEncodedString(data, []byte(value.(string))) + for _, field := range rs.Fields { + if err := packetIO.WritePacket(field.Dump(), false); err != nil { + return err } - if err := packetIO.WritePacket(data, false); err != nil { + } + if err := packetIO.WriteEOFPacket(status); err != nil { + return err + } + + if status&mysql.ServerStatusCursorExists == 0 { + for _, row := range values { + var data []byte + for _, value := range row { + data = pnet.DumpLengthEncodedString(data, []byte(value.(string))) + } + if err := packetIO.WritePacket(data, false); err != nil { + return err + } + } + if err := packetIO.WriteEOFPacket(status); err != nil { return err } } - return packetIO.WriteEOFPacket(mb.status) } return nil } // respond to LoadInFile func (mb *mockBackend) respondLoadFile(packetIO *pnet.PacketIO) error { - data := make([]byte, 0, 1+len(mockCmdStr)) - data = append(data, mysql.LocalInFileHeader) - data = append(data, []byte(mockCmdStr)...) - if err := packetIO.WritePacket(data, true); err != nil { - return err - } - for { - // read file data - pkt, err := packetIO.ReadPacket() - if err != nil { + for i := 0; i < mb.stmtNum; i++ { + status := mb.status + if i < mb.stmtNum-1 { + status |= mysql.ServerMoreResultsExists + } else { + status &= ^mysql.ServerMoreResultsExists + } + data := make([]byte, 0, 1+len(mockCmdStr)) + data = append(data, mysql.LocalInFileHeader) + data = append(data, []byte(mockCmdStr)...) + if err := packetIO.WritePacket(data, true); err != nil { return err } - // An empty packet indicates the end of file. - if len(pkt) == 0 { - break + for { + // read file data + pkt, err := packetIO.ReadPacket() + if err != nil { + return err + } + // An empty packet indicates the end of file. + if len(pkt) == 0 { + break + } + } + if err := packetIO.WriteOKPacket(status); err != nil { + return err } } - return packetIO.WriteOKPacket(mb.status) + return nil } // respond to Prepare diff --git a/pkg/proxy/backend/mock_client_test.go b/pkg/proxy/backend/mock_client_test.go index 853f605c..3289f589 100644 --- a/pkg/proxy/backend/mock_client_test.go +++ b/pkg/proxy/backend/mock_client_test.go @@ -281,40 +281,53 @@ func (mc *mockClient) query(packetIO *pnet.PacketIO) error { } func (mc *mockClient) readResultSet(packetIO *pnet.PacketIO) error { - pkt, err := packetIO.ReadPacket() - if err != nil { - return err - } - switch pkt[0] { - case mysql.OKHeader: - // check status - case mysql.ErrHeader: - return nil - case mysql.LocalInFileHeader: - for i := 0; i < mc.filePkts; i++ { - if err = packetIO.WritePacket(mockCmdBytes, false); err != nil { - return err - } - } - if err = packetIO.WritePacket(nil, true); err != nil { - return err - } - if _, err = packetIO.ReadPacket(); err != nil { + for { + var serverStatus uint16 + pkt, err := packetIO.ReadPacket() + if err != nil { return err } - default: - // read result set - for { + switch pkt[0] { + case mysql.OKHeader: + serverStatus = binary.LittleEndian.Uint16(pkt[3:]) + case mysql.ErrHeader: + return nil + case mysql.LocalInFileHeader: + for i := 0; i < mc.filePkts; i++ { + if err = packetIO.WritePacket(mockCmdBytes, false); err != nil { + return err + } + } + if err = packetIO.WritePacket(nil, true); err != nil { + return err + } if pkt, err = packetIO.ReadPacket(); err != nil { return err } - if pnet.IsEOFPacket(pkt) { - break + if pkt[0] == mysql.OKHeader { + serverStatus = binary.LittleEndian.Uint16(pkt[3:]) + } else { + return nil + } + default: + // read result set + for { + if pkt, err = packetIO.ReadPacket(); err != nil { + return err + } + if pnet.IsEOFPacket(pkt) { + break + } + } + serverStatus = binary.LittleEndian.Uint16(pkt[3:]) + if serverStatus&mysql.ServerStatusCursorExists == 0 { + if err = mc.readErrOrUntilEOF(packetIO); err != nil { + return err + } } } - serverStatus := binary.LittleEndian.Uint16(pkt[3:]) - if serverStatus&mysql.ServerStatusCursorExists == 0 { - return mc.readErrOrUntilEOF(packetIO) + if serverStatus&mysql.ServerMoreResultsExists == 0 { + break } } return nil