Skip to content

Commit

Permalink
net, config: optimize read and write packets (#382)
Browse files Browse the repository at this point in the history
  • Loading branch information
djshow832 authored Oct 18, 2023
1 parent a4d2de9 commit 35f3c79
Show file tree
Hide file tree
Showing 16 changed files with 245 additions and 91 deletions.
6 changes: 6 additions & 0 deletions conf/proxy.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
# 100 => accept as many as 100 connections.
# max-connections = 0

# It's a tradeoff between memory and performance.
# possible values:
# 0 => default value
# 1K to 16M
# conn-buffer-size = 0

[api]
# addr = "0.0.0.0:3080"

Expand Down
5 changes: 5 additions & 0 deletions lib/config/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

var (
ErrUnsupportedProxyProtocolVersion = errors.New("unsupported proxy protocol version")
ErrInvalidConfigValue = errors.New("invalid config value")
)

type Config struct {
Expand Down Expand Up @@ -46,6 +47,7 @@ type KeepAlive struct {

type ProxyServerOnline struct {
MaxConnections uint64 `yaml:"max-connections,omitempty" toml:"max-connections,omitempty" json:"max-connections,omitempty"`
ConnBufferSize int `yaml:"conn-buffer-size,omitempty" toml:"conn-buffer-size,omitempty" json:"conn-buffer-size,omitempty"`
FrontendKeepalive KeepAlive `yaml:"frontend-keepalive" toml:"frontend-keepalive" json:"frontend-keepalive"`
// BackendHealthyKeepalive applies when the observer treats the backend as healthy.
// The config values should be conservative to save CPU and tolerate network fluctuation.
Expand Down Expand Up @@ -182,6 +184,9 @@ func (cfg *Config) Check() error {
return errors.Wrapf(ErrUnsupportedProxyProtocolVersion, "%s", cfg.Proxy.ProxyProtocol)
}

if cfg.Proxy.ConnBufferSize > 0 && (cfg.Proxy.ConnBufferSize > 16*1024*1024 || cfg.Proxy.ConnBufferSize < 1024) {
return errors.Wrapf(ErrInvalidConfigValue, "conn-buffer-size must be between 1K and 16M")
}
return nil
}

Expand Down
7 changes: 7 additions & 0 deletions lib/config/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ var testProxyConfig = Config{
FrontendKeepalive: KeepAlive{Enabled: true},
ProxyProtocol: "v2",
GracefulWaitBeforeShutdown: 10,
ConnBufferSize: 32 * 1024,
},
},
API: API{
Expand Down Expand Up @@ -113,6 +114,12 @@ func TestProxyCheck(t *testing.T) {
},
err: ErrUnsupportedProxyProtocolVersion,
},
{
pre: func(t *testing.T, c *Config) {
c.Proxy.ConnBufferSize = 100 * 1024 * 1024
},
err: ErrInvalidConfigValue,
},
}
for _, tc := range testcases {
cfg := testProxyConfig
Expand Down
11 changes: 6 additions & 5 deletions pkg/proxy/backend/backend_conn_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,12 @@ const (
)

type BCConfig struct {
ProxyProtocol bool
RequireBackendTLS bool
CheckBackendInterval time.Duration
HealthyKeepAlive config.KeepAlive
UnhealthyKeepAlive config.KeepAlive
CheckBackendInterval time.Duration
ConnBufferSize int
ProxyProtocol bool
RequireBackendTLS bool
}

func (cfg *BCConfig) check() {
Expand Down Expand Up @@ -219,7 +220,7 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato
// NOTE: should use DNS name as much as possible
// Usually certs are signed with domain instead of IP addrs
// And `RemoteAddr()` will return IP addr
backendIO := pnet.NewPacketIO(cn, mgr.logger, pnet.WithRemoteAddr(addr, cn.RemoteAddr()), pnet.WithWrapError(ErrBackendConn))
backendIO := pnet.NewPacketIO(cn, mgr.logger, mgr.config.ConnBufferSize, pnet.WithRemoteAddr(addr, cn.RemoteAddr()), pnet.WithWrapError(ErrBackendConn))
mgr.backendIO.Store(backendIO)
mgr.setKeepAlive(mgr.config.HealthyKeepAlive)
return backendIO, nil
Expand Down Expand Up @@ -442,7 +443,7 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) {
mgr.handshakeHandler.OnHandshake(mgr, rs.to, rs.err)
return
}
newBackendIO := pnet.NewPacketIO(cn, mgr.logger, pnet.WithRemoteAddr(rs.to, cn.RemoteAddr()), pnet.WithWrapError(ErrBackendConn))
newBackendIO := pnet.NewPacketIO(cn, mgr.logger, mgr.config.ConnBufferSize, pnet.WithRemoteAddr(rs.to, cn.RemoteAddr()), pnet.WithWrapError(ErrBackendConn))

if rs.err = mgr.authenticator.handshakeSecondTime(mgr.logger, mgr.clientIO, newBackendIO, mgr.backendTLS, sessionToken); rs.err == nil {
rs.err = mgr.initSessionStates(newBackendIO, sessionStates)
Expand Down
8 changes: 4 additions & 4 deletions pkg/proxy/backend/backend_conn_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func (ts *backendMgrTester) firstHandshake4Proxy(clientIO, backendIO *pnet.Packe
func (ts *backendMgrTester) handshake4Backend(packetIO *pnet.PacketIO) error {
conn, err := ts.tc.backendListener.Accept()
require.NoError(ts.t, err)
ts.tc.backendIO = pnet.NewPacketIO(conn, ts.lg)
ts.tc.backendIO = pnet.NewPacketIO(conn, ts.lg, pnet.DefaultConnBufferSize)
return ts.mb.authenticate(ts.tc.backendIO)
}

Expand Down Expand Up @@ -404,7 +404,7 @@ func TestConnectFail(t *testing.T) {
backend: func(_ *pnet.PacketIO) error {
conn, err := ts.tc.backendListener.Accept()
require.NoError(ts.t, err)
ts.tc.backendIO = pnet.NewPacketIO(conn, ts.lg)
ts.tc.backendIO = pnet.NewPacketIO(conn, ts.lg, pnet.DefaultConnBufferSize)
ts.mb.authSucceed = false
return ts.mb.authenticate(ts.tc.backendIO)
},
Expand Down Expand Up @@ -448,7 +448,7 @@ func TestRedirectFail(t *testing.T) {
require.NoError(t, err)
conn, err := ts.tc.backendListener.Accept()
require.NoError(t, err)
tmpBackendIO := pnet.NewPacketIO(conn, ts.lg)
tmpBackendIO := pnet.NewPacketIO(conn, ts.lg, pnet.DefaultConnBufferSize)
// auth fails
ts.mb.authSucceed = false
err = ts.mb.authenticate(tmpBackendIO)
Expand All @@ -469,7 +469,7 @@ func TestRedirectFail(t *testing.T) {
require.NoError(ts.t, err)
conn, err := ts.tc.backendListener.Accept()
require.NoError(ts.t, err)
tmpBackendIO := pnet.NewPacketIO(conn, ts.lg)
tmpBackendIO := pnet.NewPacketIO(conn, ts.lg, pnet.DefaultConnBufferSize)
ts.mb.authSucceed = true
err = ts.mb.authenticate(tmpBackendIO)
require.NoError(t, err)
Expand Down
12 changes: 6 additions & 6 deletions pkg/proxy/backend/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,23 @@ func (tc *tcpConnSuite) newConn(t *testing.T, enableRoute bool) func() {
wg.Run(func() {
conn, err := tc.backendListener.Accept()
require.NoError(t, err)
tc.backendIO = pnet.NewPacketIO(conn, lg)
tc.backendIO = pnet.NewPacketIO(conn, lg, pnet.DefaultConnBufferSize)
})
}
wg.Run(func() {
if !enableRoute {
backendConn, err := net.Dial("tcp", tc.backendListener.Addr().String())
require.NoError(t, err)
tc.proxyBIO = pnet.NewPacketIO(backendConn, lg)
tc.proxyBIO = pnet.NewPacketIO(backendConn, lg, pnet.DefaultConnBufferSize)
}
clientConn, err := tc.proxyListener.Accept()
require.NoError(t, err)
tc.proxyCIO = pnet.NewPacketIO(clientConn, lg)
tc.proxyCIO = pnet.NewPacketIO(clientConn, lg, pnet.DefaultConnBufferSize)
})
wg.Run(func() {
conn, err := net.Dial("tcp", tc.proxyListener.Addr().String())
require.NoError(t, err)
tc.clientIO = pnet.NewPacketIO(conn, lg)
tc.clientIO = pnet.NewPacketIO(conn, lg, pnet.DefaultConnBufferSize)
})
wg.Wait()
return func() {
Expand All @@ -91,13 +91,13 @@ func (tc *tcpConnSuite) reconnectBackend(t *testing.T) {
_ = tc.backendIO.Close()
conn, err := tc.backendListener.Accept()
require.NoError(t, err)
tc.backendIO = pnet.NewPacketIO(conn, lg)
tc.backendIO = pnet.NewPacketIO(conn, lg, pnet.DefaultConnBufferSize)
})
wg.Run(func() {
_ = tc.proxyBIO.Close()
backendConn, err := net.Dial("tcp", tc.backendListener.Addr().String())
require.NoError(t, err)
tc.proxyBIO = pnet.NewPacketIO(backendConn, lg)
tc.proxyBIO = pnet.NewPacketIO(backendConn, lg, pnet.DefaultConnBufferSize)
})
wg.Wait()
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/proxy/client/client_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func NewClientConnection(logger *zap.Logger, conn net.Conn, frontendTLSConfig *t
if bcConfig.ProxyProtocol {
opts = append(opts, pnet.WithProxy)
}
pkt := pnet.NewPacketIO(conn, logger, opts...)
pkt := pnet.NewPacketIO(conn, logger, bcConfig.ConnBufferSize, opts...)
return &ClientConnection{
logger: logger,
frontendTLSConfig: frontendTLSConfig,
Expand Down
34 changes: 17 additions & 17 deletions pkg/proxy/net/compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ type compressedReadWriter struct {
logger *zap.Logger
rwStatus rwStatus
zstdLevel zstd.EncoderLevel
header []byte
sequence uint8
}

Expand All @@ -70,6 +71,7 @@ func newCompressedReadWriter(rw packetReadWriter, algorithm CompressAlgorithm, z
zstdLevel: zstd.EncoderLevelFromZstd(zstdLevel),
logger: logger,
rwStatus: rwNone,
header: make([]byte, 7),
}
}

Expand Down Expand Up @@ -100,7 +102,7 @@ func (crw *compressedReadWriter) Read(p []byte) (n int, err error) {
}
n, err = crw.readBuffer.Read(p)
// Trade off between memory and efficiency.
if n == len(p) && crw.readBuffer.Len() == 0 && crw.readBuffer.Cap() > defaultReaderSize {
if n == len(p) && crw.readBuffer.Len() == 0 && crw.readBuffer.Cap() > DefaultConnBufferSize {
crw.readBuffer = bytes.Buffer{}
}
return
Expand All @@ -110,18 +112,17 @@ func (crw *compressedReadWriter) Read(p []byte) (n int, err error) {
// The format of the protocol: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_compression_packet.html
func (crw *compressedReadWriter) readFromConn() error {
var err error
var header [7]byte
if _, err = io.ReadFull(crw.packetReadWriter, header[:]); err != nil {
if err = ReadFull(crw.packetReadWriter, crw.header); err != nil {
return err
}
compressedSequence := header[3]
compressedSequence := crw.header[3]
if compressedSequence != crw.sequence {
return ErrInvalidSequence.GenWithStack(
"invalid compressed sequence, expected %d, actual %d", crw.sequence, compressedSequence)
}
crw.sequence++
compressedLength := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
uncompressedLength := int(uint32(header[4]) | uint32(header[5])<<8 | uint32(header[6])<<16)
compressedLength := int(uint32(crw.header[0]) | uint32(crw.header[1])<<8 | uint32(crw.header[2])<<16)
uncompressedLength := int(uint32(crw.header[4]) | uint32(crw.header[5])<<8 | uint32(crw.header[6])<<16)

if uncompressedLength == 0 {
// If the data is uncompressed, the uncompressed length is 0 and compressed length is the data length
Expand All @@ -134,7 +135,7 @@ func (crw *compressedReadWriter) readFromConn() error {
// If the data is compressed, the compressed length is the length of data after the compressed header and
// the uncompressed length is the length of data after decompression.
data := make([]byte, compressedLength)
if _, err = io.ReadFull(crw.packetReadWriter, data); err != nil {
if err = ReadFull(crw.packetReadWriter, data); err != nil {
return err
}
if err = crw.uncompress(data, uncompressedLength); err != nil {
Expand Down Expand Up @@ -173,7 +174,7 @@ func (crw *compressedReadWriter) Flush() error {
return nil
}
// Trade off between memory and efficiency.
if crw.writeBuffer.Cap() > defaultWriterSize {
if crw.writeBuffer.Cap() > DefaultConnBufferSize {
crw.writeBuffer = bytes.Buffer{}
} else {
crw.writeBuffer.Reset()
Expand All @@ -193,16 +194,15 @@ func (crw *compressedReadWriter) Flush() error {
compressedLength = len(data)
}

var compressedHeader [7]byte
compressedHeader[0] = byte(compressedLength)
compressedHeader[1] = byte(compressedLength >> 8)
compressedHeader[2] = byte(compressedLength >> 16)
compressedHeader[3] = crw.sequence
compressedHeader[4] = byte(uncompressedLength)
compressedHeader[5] = byte(uncompressedLength >> 8)
compressedHeader[6] = byte(uncompressedLength >> 16)
crw.header[0] = byte(compressedLength)
crw.header[1] = byte(compressedLength >> 8)
crw.header[2] = byte(compressedLength >> 16)
crw.header[3] = crw.sequence
crw.header[4] = byte(uncompressedLength)
crw.header[5] = byte(uncompressedLength >> 8)
crw.header[6] = byte(uncompressedLength >> 16)
crw.sequence++
if _, err = crw.packetReadWriter.Write(compressedHeader[:]); err != nil {
if _, err = crw.packetReadWriter.Write(crw.header[:]); err != nil {
return errors.WithStack(err)
}
if _, err = crw.packetReadWriter.Write(data); err != nil {
Expand Down
Loading

0 comments on commit 35f3c79

Please sign in to comment.