Skip to content

Commit

Permalink
backend, net: Support compression protocol (#373)
Browse files Browse the repository at this point in the history
Co-authored-by: xhe <xw897002528@gmail.com>
  • Loading branch information
djshow832 and xhebox authored Oct 9, 2023
1 parent d3cc47c commit 419e26d
Show file tree
Hide file tree
Showing 19 changed files with 1,321 additions and 180 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ require (
github.com/gin-gonic/gin v1.8.1
github.com/go-mysql-org/go-mysql v1.6.0
github.com/go-sql-driver/mysql v1.7.0
github.com/klauspost/compress v1.16.6
github.com/pingcap/tidb v1.1.0-beta.0.20230103132820-3ccff46aa3bc
github.com/pingcap/tidb/parser v0.0.0-20230103132820-3ccff46aa3bc
github.com/pingcap/tiproxy/lib v0.0.0-00010101000000-000000000000
Expand Down
3 changes: 2 additions & 1 deletion go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,8 @@ github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.8.2/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A=
github.com/klauspost/compress v1.9.0/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A=
github.com/klauspost/compress v1.15.13 h1:NFn1Wr8cfnenSJSA46lLq4wHCcBzKTSjnBIexDMMOV0=
github.com/klauspost/compress v1.16.6 h1:91SKEy4K37vkp255cJ8QesJhjyRO0hn9i9G0GoUwLsk=
github.com/klauspost/compress v1.16.6/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek=
github.com/klauspost/cpuid v1.3.1 h1:5JNjFYYQrZeKRJ0734q51WCEEn2huer72Dc7K+R/b6s=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
Expand Down
33 changes: 27 additions & 6 deletions pkg/proxy/backend/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@ const defRequiredBackendCaps = pnet.ClientDeprecateEOF

// SupportedServerCapabilities is the default supported capabilities. Other server capabilities are not supported.
// TiDB supports ClientDeprecateEOF since v6.3.0.
// TiDB supports ClientCompress and ClientZstdCompressionAlgorithm since v7.2.0.
const SupportedServerCapabilities = pnet.ClientLongPassword | pnet.ClientFoundRows | pnet.ClientConnectWithDB |
pnet.ClientODBC | pnet.ClientLocalFiles | pnet.ClientInteractive | pnet.ClientLongFlag | pnet.ClientSSL |
pnet.ClientTransactions | pnet.ClientReserved | pnet.ClientSecureConnection | pnet.ClientMultiStatements |
pnet.ClientMultiResults | pnet.ClientPluginAuth | pnet.ClientConnectAttrs | pnet.ClientPluginAuthLenencClientData |
requiredFrontendCaps | defRequiredBackendCaps
pnet.ClientCompress | pnet.ClientZstdCompressionAlgorithm | requiredFrontendCaps | defRequiredBackendCaps

// Authenticator handshakes with the client and the backend.
type Authenticator struct {
Expand All @@ -42,6 +43,7 @@ type Authenticator struct {
attrs map[string]string
salt []byte
capability pnet.Capability
zstdLevel int
collation uint8
proxyProtocol bool
requireBackendTLS bool
Expand All @@ -64,9 +66,7 @@ func (auth *Authenticator) writeProxyProtocol(clientIO, backendIO *pnet.PacketIO
}
// either from another proxy or directly from clients, we are acting as a proxy
proxy.Command = proxyprotocol.ProxyCommandProxy
if err := backendIO.WriteProxyV2(proxy); err != nil {
return err
}
backendIO.EnableProxyClient(proxy)
}
return nil
}
Expand Down Expand Up @@ -157,6 +157,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
auth.dbname = clientResp.DB
auth.collation = clientResp.Collation
auth.attrs = clientResp.Attrs
auth.zstdLevel = clientResp.ZstdLevel

// In case of testing, backendIO is passed manually that we don't want to bother with the routing logic.
backendIO, err := getBackendIO(cctx, auth, clientResp, 15*time.Second)
Expand Down Expand Up @@ -225,6 +226,12 @@ loop:
pktIdx++
switch serverPkt[0] {
case pnet.OKHeader.Byte():
if err := setCompress(clientIO, auth.capability, auth.zstdLevel); err != nil {
return err
}
if err := setCompress(backendIO, auth.capability&backendCapability, auth.zstdLevel); err != nil {
return err
}
return nil
case pnet.ErrHeader.Byte():
return pnet.ParseErrorPacket(serverPkt)
Expand Down Expand Up @@ -277,7 +284,10 @@ func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, bac
return err
}

return auth.handleSecondAuthResult(backendIO)
if err = auth.handleSecondAuthResult(backendIO); err == nil {
return setCompress(backendIO, auth.capability&backendCapability, auth.zstdLevel)
}
return err
}

func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serverPkt []byte, capability pnet.Capability, err error) {
Expand Down Expand Up @@ -307,8 +317,9 @@ func (auth *Authenticator) writeAuthHandshake(
Attrs: auth.attrs,
Collation: auth.collation,
AuthData: authData,
Capability: auth.capability | authCap,
Capability: auth.capability&backendCapability | authCap,
AuthPlugin: authPlugin,
ZstdLevel: auth.zstdLevel,
}

if len(resp.Attrs) > 0 {
Expand Down Expand Up @@ -382,3 +393,13 @@ func (auth *Authenticator) changeUser(req *pnet.ChangeUserReq) {
func (auth *Authenticator) updateCurrentDB(db string) {
auth.dbname = db
}

func setCompress(packetIO *pnet.PacketIO, capability pnet.Capability, zstdLevel int) error {
algorithm := pnet.CompressionNone
if capability&pnet.ClientCompress > 0 {
algorithm = pnet.CompressionZlib
} else if capability&pnet.ClientZstdCompressionAlgorithm > 0 {
algorithm = pnet.CompressionZstd
}
return packetIO.SetCompressionAlgorithm(algorithm, zstdLevel)
}
159 changes: 159 additions & 0 deletions pkg/proxy/backend/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,30 @@ func TestCapability(t *testing.T) {
cfg.clientConfig.capability |= pnet.ClientSecureConnection
},
},
{
func(cfg *testConfig) {
cfg.backendConfig.capability &= ^pnet.ClientCompress
cfg.backendConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
},
func(cfg *testConfig) {
cfg.backendConfig.capability |= pnet.ClientCompress
cfg.backendConfig.capability |= pnet.ClientZstdCompressionAlgorithm
},
},
{
func(cfg *testConfig) {
cfg.clientConfig.capability &= ^pnet.ClientCompress
cfg.clientConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
},
func(cfg *testConfig) {
cfg.clientConfig.capability |= pnet.ClientCompress
cfg.clientConfig.capability |= pnet.ClientZstdCompressionAlgorithm
},
func(cfg *testConfig) {
cfg.clientConfig.capability |= pnet.ClientCompress
cfg.clientConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
},
},
}

tc := newTCPConnSuite(t)
Expand Down Expand Up @@ -387,3 +411,138 @@ func TestProxyProtocol(t *testing.T) {
clean()
}
}

func TestCompressProtocol(t *testing.T) {
cfgs := [][]cfgOverrider{
{
func(cfg *testConfig) {
cfg.backendConfig.capability &= ^pnet.ClientCompress
cfg.backendConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
},
func(cfg *testConfig) {
cfg.backendConfig.capability |= pnet.ClientCompress
cfg.backendConfig.capability |= pnet.ClientZstdCompressionAlgorithm
},
},
{
func(cfg *testConfig) {
cfg.clientConfig.capability &= ^pnet.ClientCompress
cfg.clientConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
},
func(cfg *testConfig) {
cfg.clientConfig.capability |= pnet.ClientCompress
cfg.clientConfig.capability |= pnet.ClientZstdCompressionAlgorithm
cfg.clientConfig.zstdLevel = 3
},
func(cfg *testConfig) {
cfg.clientConfig.capability |= pnet.ClientCompress
cfg.clientConfig.capability |= pnet.ClientZstdCompressionAlgorithm
cfg.clientConfig.zstdLevel = 9
},
func(cfg *testConfig) {
cfg.clientConfig.capability |= pnet.ClientCompress
cfg.clientConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
},
},
}

checker := func(t *testing.T, ts *testSuite, referCfg *testConfig) {
// If the client enables compression, client <-> proxy enables compression.
if referCfg.clientConfig.capability&pnet.ClientCompress > 0 {
require.Greater(t, ts.mp.authenticator.capability&pnet.ClientCompress, pnet.Capability(0))
require.Greater(t, ts.mc.capability&pnet.ClientCompress, pnet.Capability(0))
} else {
require.Equal(t, pnet.Capability(0), ts.mp.authenticator.capability&pnet.ClientCompress)
require.Equal(t, pnet.Capability(0), ts.mc.capability&pnet.ClientCompress)
}
// If both the client and the backend enables compression, proxy <-> backend enables compression.
if referCfg.clientConfig.capability&referCfg.backendConfig.capability&pnet.ClientCompress > 0 {
require.Greater(t, ts.mb.capability&pnet.ClientCompress, pnet.Capability(0))
} else {
require.Equal(t, pnet.Capability(0), ts.mb.capability&pnet.ClientCompress)
}
// If the client enables zstd compression, client <-> proxy enables zstd compression.
zstdCap := pnet.ClientCompress | pnet.ClientZstdCompressionAlgorithm
if referCfg.clientConfig.capability&zstdCap == zstdCap {
require.Greater(t, ts.mp.authenticator.capability&pnet.ClientZstdCompressionAlgorithm, pnet.Capability(0))
require.Greater(t, ts.mc.capability&pnet.ClientZstdCompressionAlgorithm, pnet.Capability(0))
require.Equal(t, referCfg.clientConfig.zstdLevel, ts.mp.authenticator.zstdLevel)
} else {
require.Equal(t, pnet.Capability(0), ts.mp.authenticator.capability&pnet.ClientZstdCompressionAlgorithm)
require.Equal(t, pnet.Capability(0), ts.mc.capability&pnet.ClientZstdCompressionAlgorithm)
}
// If both the client and the backend enables zstd compression, proxy <-> backend enables zstd compression.
if referCfg.clientConfig.capability&referCfg.backendConfig.capability&zstdCap == zstdCap {
require.Greater(t, ts.mb.capability&pnet.ClientZstdCompressionAlgorithm, pnet.Capability(0))
require.Equal(t, referCfg.clientConfig.zstdLevel, ts.mb.zstdLevel)
} else {
require.Equal(t, pnet.Capability(0), ts.mb.capability&pnet.ClientZstdCompressionAlgorithm)
}
}

tc := newTCPConnSuite(t)
cfgOverriders := getCfgCombinations(cfgs)
for _, cfgs := range cfgOverriders {
referCfg := newTestConfig(cfgs...)
ts, clean := newTestSuite(t, tc, cfgs...)
ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) {
checker(t, ts, referCfg)
})
ts.authenticateSecondTime(t, func(t *testing.T, ts *testSuite) {
checker(t, ts, referCfg)
})
clean()
}
}

// After upgrading the backend, the backend capability may change.
func TestUpgradeBackendCap(t *testing.T) {
cfgs := [][]cfgOverrider{
{
func(cfg *testConfig) {
cfg.clientConfig.capability &= ^pnet.ClientCompress
cfg.clientConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
},
func(cfg *testConfig) {
cfg.clientConfig.capability |= pnet.ClientCompress
cfg.clientConfig.capability |= pnet.ClientZstdCompressionAlgorithm
cfg.clientConfig.zstdLevel = 3
},
func(cfg *testConfig) {
cfg.clientConfig.capability |= pnet.ClientCompress
cfg.clientConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
},
},
{
func(cfg *testConfig) {
cfg.backendConfig.capability &= ^pnet.ClientCompress
cfg.backendConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
},
},
}

tc := newTCPConnSuite(t)
cfgOverriders := getCfgCombinations(cfgs)
for _, cfgs := range cfgOverriders {
referCfg := newTestConfig(cfgs...)
ts, clean := newTestSuite(t, tc, cfgs...)
// Before upgrade, the backend doesn't support compression.
ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) {
require.Equal(t, referCfg.clientConfig.capability&pnet.ClientCompress, ts.mp.authenticator.capability&pnet.ClientCompress)
require.Equal(t, referCfg.clientConfig.capability&pnet.ClientCompress, ts.mc.capability&pnet.ClientCompress)
require.Equal(t, pnet.Capability(0), ts.mb.capability&pnet.ClientCompress)
})
// After upgrade, the backend also supports compression.
ts.mb.backendConfig.capability |= pnet.ClientCompress
ts.mb.backendConfig.capability |= pnet.ClientZstdCompressionAlgorithm
ts.authenticateSecondTime(t, func(t *testing.T, ts *testSuite) {
require.Equal(t, referCfg.clientConfig.capability&pnet.ClientCompress, ts.mc.capability&pnet.ClientCompress)
require.Equal(t, referCfg.clientConfig.capability&pnet.ClientCompress, ts.mp.authenticator.capability&pnet.ClientCompress)
require.Equal(t, referCfg.clientConfig.capability&pnet.ClientCompress, ts.mb.capability&pnet.ClientCompress)
require.Equal(t, referCfg.clientConfig.capability&pnet.ClientZstdCompressionAlgorithm, ts.mc.capability&pnet.ClientZstdCompressionAlgorithm)
require.Equal(t, referCfg.clientConfig.capability&pnet.ClientZstdCompressionAlgorithm, ts.mp.authenticator.capability&pnet.ClientZstdCompressionAlgorithm)
require.Equal(t, referCfg.clientConfig.capability&pnet.ClientZstdCompressionAlgorithm, ts.mb.capability&pnet.ClientZstdCompressionAlgorithm)
})
clean()
}
}
20 changes: 9 additions & 11 deletions pkg/proxy/backend/backend_conn_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -527,8 +527,8 @@ func TestSpecialCmds(t *testing.T) {
require.NoError(t, ts.redirectSucceed4Backend(packetIO))
require.Equal(t, "another_user", ts.mb.username)
require.Equal(t, "session_db", ts.mb.db)
expectCap := pnet.Capability(ts.mp.handshakeHandler.GetCapability() &^ (pnet.ClientMultiStatements | pnet.ClientPluginAuthLenencClientData))
gotCap := pnet.Capability(ts.mb.capability &^ pnet.ClientPluginAuthLenencClientData)
expectCap := ts.mp.handshakeHandler.GetCapability() & defaultTestClientCapability &^ (pnet.ClientMultiStatements | pnet.ClientPluginAuthLenencClientData)
gotCap := ts.mb.capability &^ pnet.ClientPluginAuthLenencClientData
require.Equal(t, expectCap, gotCap, "expected=%s,got=%s", expectCap, gotCap)
return nil
},
Expand Down Expand Up @@ -793,18 +793,16 @@ func TestHandlerReturnError(t *testing.T) {
}

func TestOnTraffic(t *testing.T) {
i := 0
inbytes, outbytes := []int{
0x99,
}, []int{
0xce,
}
var inBytes, outBytes uint64
ts := newBackendMgrTester(t, func(config *testConfig) {
config.proxyConfig.bcConfig.CheckBackendInterval = 10 * time.Millisecond
config.proxyConfig.handler.onTraffic = func(cc ConnContext) {
require.Equal(t, uint64(inbytes[i]), cc.ClientInBytes())
require.Equal(t, uint64(outbytes[i]), cc.ClientOutBytes())
i++
require.Greater(t, cc.ClientInBytes(), uint64(0))
require.GreaterOrEqual(t, cc.ClientInBytes(), inBytes)
inBytes = cc.ClientInBytes()
require.Greater(t, cc.ClientOutBytes(), uint64(0))
require.GreaterOrEqual(t, cc.ClientOutBytes(), outBytes)
outBytes = cc.ClientOutBytes()
}
})
runners := []runner{
Expand Down
18 changes: 18 additions & 0 deletions pkg/proxy/backend/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,24 @@ func (tc *tcpConnSuite) newConn(t *testing.T, enableRoute bool) func() {
}
}

func (tc *tcpConnSuite) reconnectBackend(t *testing.T) {
lg, _ := logger.CreateLoggerForTest(t)
var wg waitgroup.WaitGroup
wg.Run(func() {
_ = tc.backendIO.Close()
conn, err := tc.backendListener.Accept()
require.NoError(t, err)
tc.backendIO = pnet.NewPacketIO(conn, lg)
})
wg.Run(func() {
_ = tc.proxyBIO.Close()
backendConn, err := net.Dial("tcp", tc.backendListener.Addr().String())
require.NoError(t, err)
tc.proxyBIO = pnet.NewPacketIO(backendConn, lg)
})
wg.Wait()
}

func (tc *tcpConnSuite) run(clientRunner, backendRunner func(*pnet.PacketIO) error, proxyRunner func(*pnet.PacketIO, *pnet.PacketIO) error) (cerr, berr, perr error) {
var wg waitgroup.WaitGroup
if clientRunner != nil {
Expand Down
13 changes: 9 additions & 4 deletions pkg/proxy/backend/mock_backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@ type mockBackend struct {
// Inputs that assigned by the test and will be sent to the client.
*backendConfig
// Outputs that received from the client and will be checked by the test.
username string
db string
attrs map[string]string
authData []byte
username string
db string
attrs map[string]string
authData []byte
zstdLevel int
}

func newMockBackend(cfg *backendConfig) *mockBackend {
Expand Down Expand Up @@ -98,6 +99,7 @@ func (mb *mockBackend) authenticate(packetIO *pnet.PacketIO) error {
mb.authData = resp.AuthData
mb.attrs = resp.Attrs
mb.capability = resp.Capability
mb.zstdLevel = resp.ZstdLevel
// verify password
return mb.verifyPassword(packetIO, resp)
}
Expand Down Expand Up @@ -125,6 +127,9 @@ func (mb *mockBackend) verifyPassword(packetIO *pnet.PacketIO, resp *pnet.Handsh
if err := packetIO.WriteOKPacket(mb.status, pnet.OKHeader); err != nil {
return err
}
if err := setCompress(packetIO, mb.capability, mb.zstdLevel); err != nil {
return err
}
} else {
if err := packetIO.WriteErrPacket(mysql.NewDefaultError(mysql.ER_ACCESS_DENIED_ERROR)); err != nil {
return err
Expand Down
2 changes: 2 additions & 0 deletions pkg/proxy/backend/mock_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type clientConfig struct {
capability pnet.Capability
collation uint8
cmd pnet.Command
zstdLevel int
// for both auth and cmd
abnormalExit bool
}
Expand Down Expand Up @@ -82,6 +83,7 @@ func (mc *mockClient) authenticate(packetIO *pnet.PacketIO) error {
AuthData: mc.authData,
Capability: mc.capability,
Collation: mc.collation,
ZstdLevel: mc.zstdLevel,
}
pkt = pnet.MakeHandshakeResponse(resp)
if mc.capability&pnet.ClientSSL > 0 {
Expand Down
Loading

0 comments on commit 419e26d

Please sign in to comment.