Skip to content

Commit

Permalink
backend, net: optimize read/write connection by forwarding packets (#391
Browse files Browse the repository at this point in the history
)

Co-authored-by: xhe <xw897002528@gmail.com>
  • Loading branch information
djshow832 and xhebox authored Nov 2, 2023
1 parent 5da388a commit c7ea81a
Show file tree
Hide file tree
Showing 12 changed files with 753 additions and 50 deletions.
2 changes: 1 addition & 1 deletion pkg/proxy/backend/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serve
if serverPkt, err = backendIO.ReadPacket(); err != nil {
return
}
if pnet.IsErrorPacket(serverPkt) {
if pnet.IsErrorPacket(serverPkt[0]) {
err = pnet.ParseErrorPacket(serverPkt)
return
}
Expand Down
42 changes: 23 additions & 19 deletions pkg/proxy/backend/cmd_processor_exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,28 +98,32 @@ func forwardOnePacket(destIO, srcIO *pnet.PacketIO, flush bool) (data []byte, er
}

func (cp *CmdProcessor) forwardUntilResultEnd(clientIO, backendIO *pnet.PacketIO, request []byte) (uint16, error) {
for {
response, err := forwardOnePacket(clientIO, backendIO, false)
if err != nil {
return 0, err
var serverStatus uint16
err := backendIO.ForwardUntil(clientIO, func(firstByte byte, length int) bool {
switch {
case pnet.IsErrorPacket(firstByte):
return true
case cp.capability&pnet.ClientDeprecateEOF == 0:
return pnet.IsEOFPacket(firstByte, length)
default:
return pnet.IsResultSetOKPacket(firstByte, length)
}
if pnet.IsErrorPacket(response) {
}, func(response []byte) error {
switch {
case pnet.IsErrorPacket(response[0]):
if err := clientIO.Flush(); err != nil {
return 0, err
}
return 0, cp.handleErrorPacket(response)
}
if cp.capability&pnet.ClientDeprecateEOF == 0 {
if pnet.IsEOFPacket(response) {
return cp.handleEOFPacket(request, response), clientIO.Flush()
}
} else {
if pnet.IsResultSetOKPacket(response) {
rs := cp.handleOKPacket(request, response)
return rs.Status, clientIO.Flush()
return err
}
return cp.handleErrorPacket(response)
case cp.capability&pnet.ClientDeprecateEOF == 0:
serverStatus = cp.handleEOFPacket(request, response)
return clientIO.Flush()
default:
serverStatus = cp.handleOKPacket(request, response).Status
return clientIO.Flush()
}
}
})
return serverStatus, err
}

func (cp *CmdProcessor) forwardPrepareCmd(clientIO, backendIO *pnet.PacketIO) error {
Expand Down Expand Up @@ -241,7 +245,7 @@ func (cp *CmdProcessor) forwardResultSet(clientIO, backendIO *pnet.PacketIO, req
if response, err = forwardOnePacket(clientIO, backendIO, false); err != nil {
return 0, err
}
if pnet.IsEOFPacket(response) {
if pnet.IsEOFPacket(response[0], len(response)) {
break
}
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/proxy/backend/cmd_processor_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (cp *CmdProcessor) readResultColumns(packetIO *pnet.PacketIO, result *gomys
if data, err = packetIO.ReadPacket(); err != nil {
return err
}
if !pnet.IsEOFPacket(data) {
if !pnet.IsEOFPacket(data[0], len(data)) {
return errors.WithStack(mysql.ErrMalformPacket)
}
result.Status = binary.LittleEndian.Uint16(data[3:])
Expand Down Expand Up @@ -103,19 +103,19 @@ func (cp *CmdProcessor) readResultRows(packetIO *pnet.PacketIO, result *gomysql.
return err
}
if cp.capability&pnet.ClientDeprecateEOF == 0 {
if pnet.IsEOFPacket(data) {
if pnet.IsEOFPacket(data[0], len(data)) {
result.Status = binary.LittleEndian.Uint16(data[3:])
break
}
} else {
if pnet.IsResultSetOKPacket(data) {
if pnet.IsResultSetOKPacket(data[0], len(data)) {
rs := pnet.ParseOKPacket(data)
result.Status = rs.Status
break
}
}
// An error may occur when the backend writes rows.
if pnet.IsErrorPacket(data) {
if pnet.IsErrorPacket(data[0]) {
return cp.handleErrorPacket(data)
}
result.RowDatas = append(result.RowDatas, data)
Expand Down
4 changes: 2 additions & 2 deletions pkg/proxy/backend/mock_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,11 +279,11 @@ func (mc *mockClient) readUntilResultEnd(packetIO *pnet.PacketIO) (pkt []byte, e
return
}
if mc.capability&pnet.ClientDeprecateEOF == 0 {
if pnet.IsEOFPacket(pkt) {
if pnet.IsEOFPacket(pkt[0], len(pkt)) {
break
}
} else {
if pnet.IsResultSetOKPacket(pkt) {
if pnet.IsResultSetOKPacket(pkt[0], len(pkt)) {
break
}
}
Expand Down
22 changes: 19 additions & 3 deletions pkg/proxy/net/compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,24 @@ func (crw *compressedReadWriter) Read(p []byte) (n int, err error) {
return
}

func (crw *compressedReadWriter) ReadFrom(r io.Reader) (n int64, err error) {
// TODO: copy compressed data directly.
buf := make([]byte, DefaultConnBufferSize)
nn := 0
for {
nn, err = r.Read(buf)
if (err == nil || err == io.EOF) && nn > 0 {
_, err = crw.Write(buf[:nn])
n += int64(nn)
}
if err == io.EOF {
return n, nil
} else if err != nil {
return n, err
}
}
}

// Read and uncompress the data into readBuffer.
// The format of the protocol: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_compression_packet.html
func (crw *compressedReadWriter) readFromConn() error {
Expand Down Expand Up @@ -227,9 +245,7 @@ func (crw *compressedReadWriter) Peek(n int) (data []byte, err error) {
return
}
}
data = make([]byte, 0, n)
copy(data, crw.readBuffer.Bytes())
return
return crw.readBuffer.Bytes()[:n], nil
}

// Discard won't be used.
Expand Down
12 changes: 6 additions & 6 deletions pkg/proxy/net/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ import (
)

var (
ErrExpectSSLRequest = errors.New("expect a SSLRequest packet")
ErrReadConn = errors.New("failed to read the connection")
ErrWriteConn = errors.New("failed to write the connection")
ErrFlushConn = errors.New("failed to flush the connection")
ErrCloseConn = errors.New("failed to close the connection")
ErrHandshakeTLS = errors.New("failed to complete tls handshake")
ErrReadConn = errors.New("failed to read the connection")
ErrWriteConn = errors.New("failed to write the connection")
ErrRelayConn = errors.New("failed to relay the connection")
ErrFlushConn = errors.New("failed to flush the connection")
ErrCloseConn = errors.New("failed to close the connection")
ErrHandshakeTLS = errors.New("failed to complete tls handshake")
)

// UserError is returned to the client.
Expand Down
16 changes: 8 additions & 8 deletions pkg/proxy/net/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,24 +412,24 @@ func ParseErrorPacket(data []byte) error {
}

// IsOKPacket returns true if it's an OK packet (but not ResultSet OK).
func IsOKPacket(data []byte) bool {
return data[0] == OKHeader.Byte()
func IsOKPacket(firstByte byte) bool {
return firstByte == OKHeader.Byte()
}

// IsEOFPacket returns true if it's an EOF packet.
func IsEOFPacket(data []byte) bool {
return data[0] == EOFHeader.Byte() && len(data) <= 5
func IsEOFPacket(firstByte byte, length int) bool {
return firstByte == EOFHeader.Byte() && length <= 5
}

// IsResultSetOKPacket returns true if it's an OK packet after the result set when CLIENT_DEPRECATE_EOF is enabled.
// A row packet may also begin with 0xfe, so we need to judge it with the packet length.
// See https://mariadb.com/kb/en/result-set-packets/
func IsResultSetOKPacket(data []byte) bool {
func IsResultSetOKPacket(firstByte byte, length int) bool {
// With CLIENT_PROTOCOL_41 enabled, the least length is 7.
return data[0] == EOFHeader.Byte() && len(data) >= 7 && len(data) < 0xFFFFFF
return firstByte == EOFHeader.Byte() && length >= 7 && length < 0xFFFFFF
}

// IsErrorPacket returns true if it's an error packet.
func IsErrorPacket(data []byte) bool {
return data[0] == ErrHeader.Byte()
func IsErrorPacket(firstByte byte) bool {
return firstByte == ErrHeader.Byte()
}
48 changes: 46 additions & 2 deletions pkg/proxy/net/packetio.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
package net

import (
"bufio"
"crypto/tls"
"io"
"net"
Expand All @@ -37,6 +36,7 @@ import (
"github.com/pingcap/tiproxy/lib/util/errors"
"github.com/pingcap/tiproxy/pkg/proxy/keepalive"
"github.com/pingcap/tiproxy/pkg/proxy/proxyprotocol"
"github.com/pingcap/tiproxy/pkg/util/bufio"
"go.uber.org/zap"
)

Expand Down Expand Up @@ -64,6 +64,7 @@ type packetReadWriter interface {
Discard(n int) (int, error)
Flush() error
DirectWrite(p []byte) (int, error)
ReadFrom(r io.Reader) (int64, error)
Proxy() *proxyprotocol.Proxy
TLSConnectionState() tls.ConnectionState
InBytes() uint64
Expand Down Expand Up @@ -110,6 +111,12 @@ func (brw *basicReadWriter) Write(p []byte) (int, error) {
return n, errors.WithStack(err)
}

func (brw *basicReadWriter) ReadFrom(r io.Reader) (int64, error) {
n, err := brw.ReadWriter.ReadFrom(r)
brw.outBytes += uint64(n)
return n, errors.WithStack(err)
}

func (brw *basicReadWriter) DirectWrite(p []byte) (int, error) {
n, err := brw.Conn.Write(p)
brw.outBytes += uint64(n)
Expand Down Expand Up @@ -187,7 +194,8 @@ type PacketIO struct {
lastKeepAlive config.KeepAlive
rawConn net.Conn
readWriter packetReadWriter
header []byte
limitReader io.LimitedReader // reuse memory to reduce allocation
header []byte // reuse memory to reduce allocation
logger *zap.Logger
remoteAddr net.Addr
wrap error
Expand Down Expand Up @@ -317,6 +325,42 @@ func (p *PacketIO) WritePacket(data []byte, flush bool) (err error) {
return nil
}

func (p *PacketIO) ForwardUntil(dest *PacketIO, isEnd func(firstByte byte, firstPktLen int) bool, process func(response []byte) error) error {
p.readWriter.BeginRW(rwRead)
dest.readWriter.BeginRW(rwWrite)
p.limitReader.R = p.readWriter
for {
header, err := p.readWriter.Peek(5)
if err != nil {
return errors.Wrap(ErrReadConn, err)
}
length := int(header[0]) | int(header[1])<<8 | int(header[2])<<16
if isEnd(header[4], length) {
// TODO: allocate a buffer from pool and return the buffer after `process`.
data, err := p.ReadPacket()
if err != nil {
return errors.Wrap(ErrReadConn, err)
}
if err := dest.WritePacket(data, false); err != nil {
return errors.Wrap(ErrWriteConn, err)
}
return process(data)
} else {
sequence, pktSequence := header[3], p.readWriter.Sequence()
if sequence != pktSequence {
return ErrInvalidSequence.GenWithStack("invalid sequence, expected %d, actual %d", pktSequence, sequence)
}
p.readWriter.SetSequence(sequence + 1)
// Sequence may be different (e.g. with compression) so we can't just copy the data to the destination.
dest.readWriter.SetSequence(dest.readWriter.Sequence() + 1)
p.limitReader.N = int64(length + 4)
if _, err := dest.readWriter.ReadFrom(&p.limitReader); err != nil {
return errors.Wrap(ErrRelayConn, err)
}
}
}
}

func (p *PacketIO) InBytes() uint64 {
return p.readWriter.InBytes()
}
Expand Down
Loading

0 comments on commit c7ea81a

Please sign in to comment.