Skip to content

Commit

Permalink
backend: add tests for forwarding multi-statements (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
djshow832 authored Aug 11, 2022
1 parent 5bd467f commit 930663c
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 59 deletions.
44 changes: 44 additions & 0 deletions pkg/proxy/backend/cmd_processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -963,3 +963,47 @@ func TestBeginStmt(t *testing.T) {
require.Equal(t, test.isBegin, isBeginStmt(test.stmt))
}
}

// Test forwarding multi-statements works well.
func TestMultiStmt(t *testing.T) {
tc := newTCPConnSuite(t)

// COM_STMT_PREPARE don't support multiple statements, so we only test COM_QUERY.
cfgs := []cfgOverrider{
func(cfg *testConfig) {
cfg.clientConfig.cmd = mysql.ComQuery
cfg.backendConfig.respondType = responseTypeOK
cfg.backendConfig.stmtNum = 2
},
func(cfg *testConfig) {
cfg.clientConfig.cmd = mysql.ComQuery
cfg.backendConfig.respondType = responseTypeResultSet
cfg.backendConfig.columns = 1
cfg.backendConfig.stmtNum = 2
},
func(cfg *testConfig) {
cfg.clientConfig.cmd = mysql.ComQuery
cfg.backendConfig.respondType = responseTypeResultSet
cfg.backendConfig.columns = 1
cfg.backendConfig.rows = 1
cfg.backendConfig.stmtNum = 2
},
func(cfg *testConfig) {
cfg.clientConfig.cmd = mysql.ComQuery
cfg.backendConfig.respondType = responseTypeResultSet
cfg.backendConfig.columns = 1
cfg.backendConfig.stmtNum = 3
},
func(cfg *testConfig) {
cfg.clientConfig.cmd = mysql.ComQuery
cfg.backendConfig.respondType = responseTypeLoadFile
cfg.backendConfig.stmtNum = 2
},
}

for _, cfg := range cfgs {
ts, clean := newTestSuite(t, tc, cfg)
ts.executeCmd(t, nil)
clean()
}
}
103 changes: 71 additions & 32 deletions pkg/proxy/backend/mock_backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ type backendConfig struct {
params int
status uint16
loops int
stmtNum int
}

func newBackendConfig() *backendConfig {
Expand All @@ -48,6 +49,7 @@ func newBackendConfig() *backendConfig {
switchAuth: true,
authSucceed: true,
loops: 1,
stmtNum: 1,
}
}

Expand Down Expand Up @@ -147,7 +149,7 @@ func (mb *mockBackend) respondOnce(packetIO *pnet.PacketIO) error {
}
switch mb.respondType {
case responseTypeOK:
return packetIO.WriteOKPacket(mb.status)
return mb.respondOK(packetIO)
case responseTypeErr:
return packetIO.WriteErrPacket(mysql.NewErr(mysql.ErrUnknown))
case responseTypeResultSet:
Expand Down Expand Up @@ -178,6 +180,21 @@ func (mb *mockBackend) respondOnce(packetIO *pnet.PacketIO) error {
return packetIO.WriteErrPacket(mysql.NewErr(mysql.ErrUnknown))
}

func (mb *mockBackend) respondOK(packetIO *pnet.PacketIO) error {
for i := 0; i < mb.stmtNum; i++ {
status := mb.status
if i < mb.stmtNum-1 {
status |= mysql.ServerMoreResultsExists
} else {
status &= ^mysql.ServerMoreResultsExists
}
if err := packetIO.WriteOKPacket(status); err != nil {
return err
}
}
return nil
}

// respond to FieldList
func (mb *mockBackend) respondColumns(packetIO *pnet.PacketIO) error {
for i := 0; i < mb.columns; i++ {
Expand Down Expand Up @@ -216,53 +233,75 @@ func (mb *mockBackend) respondResultSet(packetIO *pnet.PacketIO) error {
if err != nil {
return err
}
data := pnet.DumpLengthEncodedInt(nil, uint64(mb.columns))
if err := packetIO.WritePacket(data, false); err != nil {
return err
}
for _, field := range rs.Fields {
if err := packetIO.WritePacket(field.Dump(), false); err != nil {
for i := 0; i < mb.stmtNum; i++ {
status := mb.status
if i < mb.stmtNum-1 {
status |= mysql.ServerMoreResultsExists
} else {
status &= ^mysql.ServerMoreResultsExists
}
data := pnet.DumpLengthEncodedInt(nil, uint64(mb.columns))
if err := packetIO.WritePacket(data, false); err != nil {
return err
}
}
if err := packetIO.WriteEOFPacket(mb.status); err != nil {
return err
}
if mb.status&mysql.ServerStatusCursorExists == 0 {
for _, row := range values {
var data []byte
for _, value := range row {
data = pnet.DumpLengthEncodedString(data, []byte(value.(string)))
for _, field := range rs.Fields {
if err := packetIO.WritePacket(field.Dump(), false); err != nil {
return err
}
if err := packetIO.WritePacket(data, false); err != nil {
}
if err := packetIO.WriteEOFPacket(status); err != nil {
return err
}

if status&mysql.ServerStatusCursorExists == 0 {
for _, row := range values {
var data []byte
for _, value := range row {
data = pnet.DumpLengthEncodedString(data, []byte(value.(string)))
}
if err := packetIO.WritePacket(data, false); err != nil {
return err
}
}
if err := packetIO.WriteEOFPacket(status); err != nil {
return err
}
}
return packetIO.WriteEOFPacket(mb.status)
}
return nil
}

// respond to LoadInFile
func (mb *mockBackend) respondLoadFile(packetIO *pnet.PacketIO) error {
data := make([]byte, 0, 1+len(mockCmdStr))
data = append(data, mysql.LocalInFileHeader)
data = append(data, []byte(mockCmdStr)...)
if err := packetIO.WritePacket(data, true); err != nil {
return err
}
for {
// read file data
pkt, err := packetIO.ReadPacket()
if err != nil {
for i := 0; i < mb.stmtNum; i++ {
status := mb.status
if i < mb.stmtNum-1 {
status |= mysql.ServerMoreResultsExists
} else {
status &= ^mysql.ServerMoreResultsExists
}
data := make([]byte, 0, 1+len(mockCmdStr))
data = append(data, mysql.LocalInFileHeader)
data = append(data, []byte(mockCmdStr)...)
if err := packetIO.WritePacket(data, true); err != nil {
return err
}
// An empty packet indicates the end of file.
if len(pkt) == 0 {
break
for {
// read file data
pkt, err := packetIO.ReadPacket()
if err != nil {
return err
}
// An empty packet indicates the end of file.
if len(pkt) == 0 {
break
}
}
if err := packetIO.WriteOKPacket(status); err != nil {
return err
}
}
return packetIO.WriteOKPacket(mb.status)
return nil
}

// respond to Prepare
Expand Down
67 changes: 40 additions & 27 deletions pkg/proxy/backend/mock_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,40 +281,53 @@ func (mc *mockClient) query(packetIO *pnet.PacketIO) error {
}

func (mc *mockClient) readResultSet(packetIO *pnet.PacketIO) error {
pkt, err := packetIO.ReadPacket()
if err != nil {
return err
}
switch pkt[0] {
case mysql.OKHeader:
// check status
case mysql.ErrHeader:
return nil
case mysql.LocalInFileHeader:
for i := 0; i < mc.filePkts; i++ {
if err = packetIO.WritePacket(mockCmdBytes, false); err != nil {
return err
}
}
if err = packetIO.WritePacket(nil, true); err != nil {
return err
}
if _, err = packetIO.ReadPacket(); err != nil {
for {
var serverStatus uint16
pkt, err := packetIO.ReadPacket()
if err != nil {
return err
}
default:
// read result set
for {
switch pkt[0] {
case mysql.OKHeader:
serverStatus = binary.LittleEndian.Uint16(pkt[3:])
case mysql.ErrHeader:
return nil
case mysql.LocalInFileHeader:
for i := 0; i < mc.filePkts; i++ {
if err = packetIO.WritePacket(mockCmdBytes, false); err != nil {
return err
}
}
if err = packetIO.WritePacket(nil, true); err != nil {
return err
}
if pkt, err = packetIO.ReadPacket(); err != nil {
return err
}
if pnet.IsEOFPacket(pkt) {
break
if pkt[0] == mysql.OKHeader {
serverStatus = binary.LittleEndian.Uint16(pkt[3:])
} else {
return nil
}
default:
// read result set
for {
if pkt, err = packetIO.ReadPacket(); err != nil {
return err
}
if pnet.IsEOFPacket(pkt) {
break
}
}
serverStatus = binary.LittleEndian.Uint16(pkt[3:])
if serverStatus&mysql.ServerStatusCursorExists == 0 {
if err = mc.readErrOrUntilEOF(packetIO); err != nil {
return err
}
}
}
serverStatus := binary.LittleEndian.Uint16(pkt[3:])
if serverStatus&mysql.ServerStatusCursorExists == 0 {
return mc.readErrOrUntilEOF(packetIO)
if serverStatus&mysql.ServerMoreResultsExists == 0 {
break
}
}
return nil
Expand Down

0 comments on commit 930663c

Please sign in to comment.