diff --git a/go.mod b/go.mod index b514da75..ab8f3965 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/gin-gonic/gin v1.8.1 github.com/go-mysql-org/go-mysql v1.6.0 github.com/go-sql-driver/mysql v1.7.0 + github.com/klauspost/compress v1.16.6 github.com/pingcap/tidb v1.1.0-beta.0.20230103132820-3ccff46aa3bc github.com/pingcap/tidb/parser v0.0.0-20230103132820-3ccff46aa3bc github.com/pingcap/tiproxy/lib v0.0.0-00010101000000-000000000000 diff --git a/go.sum b/go.sum index 2f2ce436..845c5ffa 100644 --- a/go.sum +++ b/go.sum @@ -415,7 +415,8 @@ github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.8.2/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= github.com/klauspost/compress v1.9.0/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= -github.com/klauspost/compress v1.15.13 h1:NFn1Wr8cfnenSJSA46lLq4wHCcBzKTSjnBIexDMMOV0= +github.com/klauspost/compress v1.16.6 h1:91SKEy4K37vkp255cJ8QesJhjyRO0hn9i9G0GoUwLsk= +github.com/klauspost/compress v1.16.6/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= github.com/klauspost/cpuid v1.3.1 h1:5JNjFYYQrZeKRJ0734q51WCEEn2huer72Dc7K+R/b6s= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index 3841a09c..2889bdce 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -29,11 +29,12 @@ const defRequiredBackendCaps = pnet.ClientDeprecateEOF // SupportedServerCapabilities is the default supported capabilities. Other server capabilities are not supported. // TiDB supports ClientDeprecateEOF since v6.3.0. +// TiDB supports ClientCompress and ClientZstdCompressionAlgorithm since v7.2.0. 
const SupportedServerCapabilities = pnet.ClientLongPassword | pnet.ClientFoundRows | pnet.ClientConnectWithDB | pnet.ClientODBC | pnet.ClientLocalFiles | pnet.ClientInteractive | pnet.ClientLongFlag | pnet.ClientSSL | pnet.ClientTransactions | pnet.ClientReserved | pnet.ClientSecureConnection | pnet.ClientMultiStatements | pnet.ClientMultiResults | pnet.ClientPluginAuth | pnet.ClientConnectAttrs | pnet.ClientPluginAuthLenencClientData | - requiredFrontendCaps | defRequiredBackendCaps + pnet.ClientCompress | pnet.ClientZstdCompressionAlgorithm | requiredFrontendCaps | defRequiredBackendCaps // Authenticator handshakes with the client and the backend. type Authenticator struct { @@ -42,6 +43,7 @@ type Authenticator struct { attrs map[string]string salt []byte capability pnet.Capability + zstdLevel int collation uint8 proxyProtocol bool requireBackendTLS bool @@ -64,9 +66,7 @@ func (auth *Authenticator) writeProxyProtocol(clientIO, backendIO *pnet.PacketIO } // either from another proxy or directly from clients, we are acting as a proxy proxy.Command = proxyprotocol.ProxyCommandProxy - if err := backendIO.WriteProxyV2(proxy); err != nil { - return err - } + backendIO.EnableProxyClient(proxy) } return nil } @@ -157,6 +157,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte auth.dbname = clientResp.DB auth.collation = clientResp.Collation auth.attrs = clientResp.Attrs + auth.zstdLevel = clientResp.ZstdLevel // In case of testing, backendIO is passed manually that we don't want to bother with the routing logic. 
backendIO, err := getBackendIO(cctx, auth, clientResp, 15*time.Second) @@ -225,6 +226,12 @@ loop: pktIdx++ switch serverPkt[0] { case pnet.OKHeader.Byte(): + if err := setCompress(clientIO, auth.capability, auth.zstdLevel); err != nil { + return err + } + if err := setCompress(backendIO, auth.capability&backendCapability, auth.zstdLevel); err != nil { + return err + } return nil case pnet.ErrHeader.Byte(): return pnet.ParseErrorPacket(serverPkt) @@ -277,7 +284,10 @@ func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, bac return err } - return auth.handleSecondAuthResult(backendIO) + if err = auth.handleSecondAuthResult(backendIO); err == nil { + return setCompress(backendIO, auth.capability&backendCapability, auth.zstdLevel) + } + return err } func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serverPkt []byte, capability pnet.Capability, err error) { @@ -307,8 +317,9 @@ func (auth *Authenticator) writeAuthHandshake( Attrs: auth.attrs, Collation: auth.collation, AuthData: authData, - Capability: auth.capability | authCap, + Capability: auth.capability&backendCapability | authCap, AuthPlugin: authPlugin, + ZstdLevel: auth.zstdLevel, } if len(resp.Attrs) > 0 { @@ -382,3 +393,13 @@ func (auth *Authenticator) changeUser(req *pnet.ChangeUserReq) { func (auth *Authenticator) updateCurrentDB(db string) { auth.dbname = db } + +func setCompress(packetIO *pnet.PacketIO, capability pnet.Capability, zstdLevel int) error { + algorithm := pnet.CompressionNone + if capability&pnet.ClientCompress > 0 { + algorithm = pnet.CompressionZlib + } else if capability&pnet.ClientZstdCompressionAlgorithm > 0 { + algorithm = pnet.CompressionZstd + } + return packetIO.SetCompressionAlgorithm(algorithm, zstdLevel) +} diff --git a/pkg/proxy/backend/authenticator_test.go b/pkg/proxy/backend/authenticator_test.go index 9fcbc89c..47d3e99d 100644 --- a/pkg/proxy/backend/authenticator_test.go +++ b/pkg/proxy/backend/authenticator_test.go @@ -164,6 
+164,30 @@ func TestCapability(t *testing.T) { cfg.clientConfig.capability |= pnet.ClientSecureConnection }, }, + { + func(cfg *testConfig) { + cfg.backendConfig.capability &= ^pnet.ClientCompress + cfg.backendConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm + }, + func(cfg *testConfig) { + cfg.backendConfig.capability |= pnet.ClientCompress + cfg.backendConfig.capability |= pnet.ClientZstdCompressionAlgorithm + }, + }, + { + func(cfg *testConfig) { + cfg.clientConfig.capability &= ^pnet.ClientCompress + cfg.clientConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm + }, + func(cfg *testConfig) { + cfg.clientConfig.capability |= pnet.ClientCompress + cfg.clientConfig.capability |= pnet.ClientZstdCompressionAlgorithm + }, + func(cfg *testConfig) { + cfg.clientConfig.capability |= pnet.ClientCompress + cfg.clientConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm + }, + }, } tc := newTCPConnSuite(t) @@ -387,3 +411,138 @@ func TestProxyProtocol(t *testing.T) { clean() } } + +func TestCompressProtocol(t *testing.T) { + cfgs := [][]cfgOverrider{ + { + func(cfg *testConfig) { + cfg.backendConfig.capability &= ^pnet.ClientCompress + cfg.backendConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm + }, + func(cfg *testConfig) { + cfg.backendConfig.capability |= pnet.ClientCompress + cfg.backendConfig.capability |= pnet.ClientZstdCompressionAlgorithm + }, + }, + { + func(cfg *testConfig) { + cfg.clientConfig.capability &= ^pnet.ClientCompress + cfg.clientConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm + }, + func(cfg *testConfig) { + cfg.clientConfig.capability |= pnet.ClientCompress + cfg.clientConfig.capability |= pnet.ClientZstdCompressionAlgorithm + cfg.clientConfig.zstdLevel = 3 + }, + func(cfg *testConfig) { + cfg.clientConfig.capability |= pnet.ClientCompress + cfg.clientConfig.capability |= pnet.ClientZstdCompressionAlgorithm + cfg.clientConfig.zstdLevel = 9 + }, + func(cfg *testConfig) { + cfg.clientConfig.capability |= 
pnet.ClientCompress + cfg.clientConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm + }, + }, + } + + checker := func(t *testing.T, ts *testSuite, referCfg *testConfig) { + // If the client enables compression, client <-> proxy enables compression. + if referCfg.clientConfig.capability&pnet.ClientCompress > 0 { + require.Greater(t, ts.mp.authenticator.capability&pnet.ClientCompress, pnet.Capability(0)) + require.Greater(t, ts.mc.capability&pnet.ClientCompress, pnet.Capability(0)) + } else { + require.Equal(t, pnet.Capability(0), ts.mp.authenticator.capability&pnet.ClientCompress) + require.Equal(t, pnet.Capability(0), ts.mc.capability&pnet.ClientCompress) + } + // If both the client and the backend enables compression, proxy <-> backend enables compression. + if referCfg.clientConfig.capability&referCfg.backendConfig.capability&pnet.ClientCompress > 0 { + require.Greater(t, ts.mb.capability&pnet.ClientCompress, pnet.Capability(0)) + } else { + require.Equal(t, pnet.Capability(0), ts.mb.capability&pnet.ClientCompress) + } + // If the client enables zstd compression, client <-> proxy enables zstd compression. + zstdCap := pnet.ClientCompress | pnet.ClientZstdCompressionAlgorithm + if referCfg.clientConfig.capability&zstdCap == zstdCap { + require.Greater(t, ts.mp.authenticator.capability&pnet.ClientZstdCompressionAlgorithm, pnet.Capability(0)) + require.Greater(t, ts.mc.capability&pnet.ClientZstdCompressionAlgorithm, pnet.Capability(0)) + require.Equal(t, referCfg.clientConfig.zstdLevel, ts.mp.authenticator.zstdLevel) + } else { + require.Equal(t, pnet.Capability(0), ts.mp.authenticator.capability&pnet.ClientZstdCompressionAlgorithm) + require.Equal(t, pnet.Capability(0), ts.mc.capability&pnet.ClientZstdCompressionAlgorithm) + } + // If both the client and the backend enables zstd compression, proxy <-> backend enables zstd compression. 
+ if referCfg.clientConfig.capability&referCfg.backendConfig.capability&zstdCap == zstdCap { + require.Greater(t, ts.mb.capability&pnet.ClientZstdCompressionAlgorithm, pnet.Capability(0)) + require.Equal(t, referCfg.clientConfig.zstdLevel, ts.mb.zstdLevel) + } else { + require.Equal(t, pnet.Capability(0), ts.mb.capability&pnet.ClientZstdCompressionAlgorithm) + } + } + + tc := newTCPConnSuite(t) + cfgOverriders := getCfgCombinations(cfgs) + for _, cfgs := range cfgOverriders { + referCfg := newTestConfig(cfgs...) + ts, clean := newTestSuite(t, tc, cfgs...) + ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) { + checker(t, ts, referCfg) + }) + ts.authenticateSecondTime(t, func(t *testing.T, ts *testSuite) { + checker(t, ts, referCfg) + }) + clean() + } +} + +// After upgrading the backend, the backend capability may change. +func TestUpgradeBackendCap(t *testing.T) { + cfgs := [][]cfgOverrider{ + { + func(cfg *testConfig) { + cfg.clientConfig.capability &= ^pnet.ClientCompress + cfg.clientConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm + }, + func(cfg *testConfig) { + cfg.clientConfig.capability |= pnet.ClientCompress + cfg.clientConfig.capability |= pnet.ClientZstdCompressionAlgorithm + cfg.clientConfig.zstdLevel = 3 + }, + func(cfg *testConfig) { + cfg.clientConfig.capability |= pnet.ClientCompress + cfg.clientConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm + }, + }, + { + func(cfg *testConfig) { + cfg.backendConfig.capability &= ^pnet.ClientCompress + cfg.backendConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm + }, + }, + } + + tc := newTCPConnSuite(t) + cfgOverriders := getCfgCombinations(cfgs) + for _, cfgs := range cfgOverriders { + referCfg := newTestConfig(cfgs...) + ts, clean := newTestSuite(t, tc, cfgs...) + // Before upgrade, the backend doesn't support compression. 
+ ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) { + require.Equal(t, referCfg.clientConfig.capability&pnet.ClientCompress, ts.mp.authenticator.capability&pnet.ClientCompress) + require.Equal(t, referCfg.clientConfig.capability&pnet.ClientCompress, ts.mc.capability&pnet.ClientCompress) + require.Equal(t, pnet.Capability(0), ts.mb.capability&pnet.ClientCompress) + }) + // After upgrade, the backend also supports compression. + ts.mb.backendConfig.capability |= pnet.ClientCompress + ts.mb.backendConfig.capability |= pnet.ClientZstdCompressionAlgorithm + ts.authenticateSecondTime(t, func(t *testing.T, ts *testSuite) { + require.Equal(t, referCfg.clientConfig.capability&pnet.ClientCompress, ts.mc.capability&pnet.ClientCompress) + require.Equal(t, referCfg.clientConfig.capability&pnet.ClientCompress, ts.mp.authenticator.capability&pnet.ClientCompress) + require.Equal(t, referCfg.clientConfig.capability&pnet.ClientCompress, ts.mb.capability&pnet.ClientCompress) + require.Equal(t, referCfg.clientConfig.capability&pnet.ClientZstdCompressionAlgorithm, ts.mc.capability&pnet.ClientZstdCompressionAlgorithm) + require.Equal(t, referCfg.clientConfig.capability&pnet.ClientZstdCompressionAlgorithm, ts.mp.authenticator.capability&pnet.ClientZstdCompressionAlgorithm) + require.Equal(t, referCfg.clientConfig.capability&pnet.ClientZstdCompressionAlgorithm, ts.mb.capability&pnet.ClientZstdCompressionAlgorithm) + }) + clean() + } +} diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index b7ea68e4..6a76347d 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -527,8 +527,8 @@ func TestSpecialCmds(t *testing.T) { require.NoError(t, ts.redirectSucceed4Backend(packetIO)) require.Equal(t, "another_user", ts.mb.username) require.Equal(t, "session_db", ts.mb.db) - expectCap := pnet.Capability(ts.mp.handshakeHandler.GetCapability() &^ (pnet.ClientMultiStatements | 
pnet.ClientPluginAuthLenencClientData)) - gotCap := pnet.Capability(ts.mb.capability &^ pnet.ClientPluginAuthLenencClientData) + expectCap := ts.mp.handshakeHandler.GetCapability() & defaultTestClientCapability &^ (pnet.ClientMultiStatements | pnet.ClientPluginAuthLenencClientData) + gotCap := ts.mb.capability &^ pnet.ClientPluginAuthLenencClientData require.Equal(t, expectCap, gotCap, "expected=%s,got=%s", expectCap, gotCap) return nil }, @@ -793,18 +793,16 @@ func TestHandlerReturnError(t *testing.T) { } func TestOnTraffic(t *testing.T) { - i := 0 - inbytes, outbytes := []int{ - 0x99, - }, []int{ - 0xce, - } + var inBytes, outBytes uint64 ts := newBackendMgrTester(t, func(config *testConfig) { config.proxyConfig.bcConfig.CheckBackendInterval = 10 * time.Millisecond config.proxyConfig.handler.onTraffic = func(cc ConnContext) { - require.Equal(t, uint64(inbytes[i]), cc.ClientInBytes()) - require.Equal(t, uint64(outbytes[i]), cc.ClientOutBytes()) - i++ + require.Greater(t, cc.ClientInBytes(), uint64(0)) + require.GreaterOrEqual(t, cc.ClientInBytes(), inBytes) + inBytes = cc.ClientInBytes() + require.Greater(t, cc.ClientOutBytes(), uint64(0)) + require.GreaterOrEqual(t, cc.ClientOutBytes(), outBytes) + outBytes = cc.ClientOutBytes() } }) runners := []runner{ diff --git a/pkg/proxy/backend/common_test.go b/pkg/proxy/backend/common_test.go index bf5cfdd3..ead11711 100644 --- a/pkg/proxy/backend/common_test.go +++ b/pkg/proxy/backend/common_test.go @@ -84,6 +84,24 @@ func (tc *tcpConnSuite) newConn(t *testing.T, enableRoute bool) func() { } } +func (tc *tcpConnSuite) reconnectBackend(t *testing.T) { + lg, _ := logger.CreateLoggerForTest(t) + var wg waitgroup.WaitGroup + wg.Run(func() { + _ = tc.backendIO.Close() + conn, err := tc.backendListener.Accept() + require.NoError(t, err) + tc.backendIO = pnet.NewPacketIO(conn, lg) + }) + wg.Run(func() { + _ = tc.proxyBIO.Close() + backendConn, err := net.Dial("tcp", tc.backendListener.Addr().String()) + require.NoError(t, err) 
+ tc.proxyBIO = pnet.NewPacketIO(backendConn, lg) + }) + wg.Wait() +} + func (tc *tcpConnSuite) run(clientRunner, backendRunner func(*pnet.PacketIO) error, proxyRunner func(*pnet.PacketIO, *pnet.PacketIO) error) (cerr, berr, perr error) { var wg waitgroup.WaitGroup if clientRunner != nil { diff --git a/pkg/proxy/backend/mock_backend_test.go b/pkg/proxy/backend/mock_backend_test.go index cc7d453a..303a87f1 100644 --- a/pkg/proxy/backend/mock_backend_test.go +++ b/pkg/proxy/backend/mock_backend_test.go @@ -47,10 +47,11 @@ type mockBackend struct { // Inputs that assigned by the test and will be sent to the client. *backendConfig // Outputs that received from the client and will be checked by the test. - username string - db string - attrs map[string]string - authData []byte + username string + db string + attrs map[string]string + authData []byte + zstdLevel int } func newMockBackend(cfg *backendConfig) *mockBackend { @@ -98,6 +99,7 @@ func (mb *mockBackend) authenticate(packetIO *pnet.PacketIO) error { mb.authData = resp.AuthData mb.attrs = resp.Attrs mb.capability = resp.Capability + mb.zstdLevel = resp.ZstdLevel // verify password return mb.verifyPassword(packetIO, resp) } @@ -125,6 +127,9 @@ func (mb *mockBackend) verifyPassword(packetIO *pnet.PacketIO, resp *pnet.Handsh if err := packetIO.WriteOKPacket(mb.status, pnet.OKHeader); err != nil { return err } + if err := setCompress(packetIO, mb.capability, mb.zstdLevel); err != nil { + return err + } } else { if err := packetIO.WriteErrPacket(mysql.NewDefaultError(mysql.ER_ACCESS_DENIED_ERROR)); err != nil { return err diff --git a/pkg/proxy/backend/mock_client_test.go b/pkg/proxy/backend/mock_client_test.go index de28a2c8..188f9da4 100644 --- a/pkg/proxy/backend/mock_client_test.go +++ b/pkg/proxy/backend/mock_client_test.go @@ -26,6 +26,7 @@ type clientConfig struct { capability pnet.Capability collation uint8 cmd pnet.Command + zstdLevel int // for both auth and cmd abnormalExit bool } @@ -82,6 +83,7 @@ func (mc 
*mockClient) authenticate(packetIO *pnet.PacketIO) error { AuthData: mc.authData, Capability: mc.capability, Collation: mc.collation, + ZstdLevel: mc.zstdLevel, } pkt = pnet.MakeHandshakeResponse(resp) if mc.capability&pnet.ClientSSL > 0 { diff --git a/pkg/proxy/backend/testsuite_test.go b/pkg/proxy/backend/testsuite_test.go index 13984e3b..f323e525 100644 --- a/pkg/proxy/backend/testsuite_test.go +++ b/pkg/proxy/backend/testsuite_test.go @@ -22,13 +22,13 @@ import ( // sent from the server and vice versa. const ( - defaultTestBackendCapability = pnet.ClientLongPassword | pnet.ClientFoundRows | pnet.ClientLongFlag | + defaultTestClientCapability = pnet.ClientLongPassword | pnet.ClientFoundRows | pnet.ClientLongFlag | pnet.ClientConnectWithDB | pnet.ClientNoSchema | pnet.ClientODBC | pnet.ClientLocalFiles | pnet.ClientIgnoreSpace | pnet.ClientProtocol41 | pnet.ClientInteractive | pnet.ClientSSL | pnet.ClientIgnoreSigpipe | pnet.ClientTransactions | pnet.ClientReserved | pnet.ClientSecureConnection | pnet.ClientMultiStatements | pnet.ClientMultiResults | pnet.ClientPluginAuth | pnet.ClientConnectAttrs | pnet.ClientPluginAuthLenencClientData | pnet.ClientDeprecateEOF - defaultTestClientCapability = defaultTestBackendCapability + defaultTestBackendCapability = defaultTestClientCapability | pnet.ClientCompress | pnet.ClientZstdCompressionAlgorithm ) var ( @@ -197,6 +197,7 @@ func (ts *testSuite) authenticateFirstTime(t *testing.T, c checker) { // This must be called after authenticateFirstTime. 
func (ts *testSuite) authenticateSecondTime(t *testing.T, c checker) { ts.mb.backendConfig.authSucceed = true + ts.tc.reconnectBackend(t) ts.runAndCheck(t, c, nil, ts.mb.authenticate, ts.mp.authenticateSecondTime) if c == nil { require.Equal(t, ts.mc.username, ts.mb.username) diff --git a/pkg/proxy/net/compress.go b/pkg/proxy/net/compress.go new file mode 100644 index 00000000..5e6b8d63 --- /dev/null +++ b/pkg/proxy/net/compress.go @@ -0,0 +1,307 @@ +// Copyright 2023 PingCAP, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package net + +import ( + "bytes" + "compress/zlib" + "io" + + "github.com/klauspost/compress/zstd" + "github.com/pingcap/tiproxy/lib/util/errors" + "go.uber.org/zap" +) + +// CompressAlgorithm is the algorithm for MySQL compressed protocol. +type CompressAlgorithm int + +const ( + // CompressionNone indicates no compression in use. + CompressionNone CompressAlgorithm = iota + // CompressionZlib is zlib/deflate. + CompressionZlib + // CompressionZstd is Facebook's Zstandard. + CompressionZstd +) + +type rwStatus int + +const ( + rwNone rwStatus = iota + rwRead + rwWrite +) + +const ( + // maxCompressedSize is the max uncompressed data size for a compressed packet. + // Packets bigger than maxCompressedSize will be split into multiple compressed packets. + // MySQL has 16K for the first packet. The rest packets and MySQL Connector/J are 16M. + // Two restrictions for the length: + // - it should be smaller than 16M so that the length can fit in the 3 byte field in the header. + // - it should be larger than 4M so that the compressed sequence can fit in the 3 byte field when max_allowed_packet is 1G. + maxCompressedSize = 1<<24 - 1 + // minCompressSize is the min uncompressed data size for compressed data. + // Packets smaller than minCompressSize won't be compressed. + // MySQL and MySQL Connector/J are both 50. + minCompressSize = 50 + // defaultZlibLevel is the compression level for zlib. MySQL is 6. 
+ zlibCompressionLevel = 6 +) + +func (p *PacketIO) SetCompressionAlgorithm(algorithm CompressAlgorithm, zstdLevel int) error { + switch algorithm { + case CompressionZlib, CompressionZstd: + p.readWriter = newCompressedReadWriter(p.readWriter, algorithm, zstdLevel, p.logger) + case CompressionNone: + default: + return errors.Errorf("Unknown compression algorithm %d", algorithm) + } + return nil +} + +var _ packetReadWriter = (*compressedReadWriter)(nil) + +type compressedReadWriter struct { + packetReadWriter + readBuffer bytes.Buffer + writeBuffer bytes.Buffer + algorithm CompressAlgorithm + logger *zap.Logger + rwStatus rwStatus + zstdLevel zstd.EncoderLevel + sequence uint8 +} + +func newCompressedReadWriter(rw packetReadWriter, algorithm CompressAlgorithm, zstdLevel int, logger *zap.Logger) *compressedReadWriter { + return &compressedReadWriter{ + packetReadWriter: rw, + algorithm: algorithm, + zstdLevel: zstd.EncoderLevelFromZstd(zstdLevel), + logger: logger, + rwStatus: rwNone, + } +} + +func (crw *compressedReadWriter) SetSequence(sequence uint8) { + crw.packetReadWriter.SetSequence(sequence) + // Reset the compressed sequence before the next command. + if sequence == 0 { + crw.sequence = 0 + crw.rwStatus = rwNone + } +} + +// Uncompressed sequence of MySQL doesn't follow the spec: it's set to the compressed sequence when +// the client/server begins reading or writing. +func (crw *compressedReadWriter) beginRW(status rwStatus) { + if crw.rwStatus != status { + crw.packetReadWriter.SetSequence(crw.sequence) + crw.rwStatus = status + } +} + +func (crw *compressedReadWriter) Read(p []byte) (n int, err error) { + crw.beginRW(rwRead) + // Read from the connection to fill the buffer if the buffer is empty. + if crw.readBuffer.Len() == 0 { + if err = crw.readFromConn(); err != nil { + return + } + } + n, err = crw.readBuffer.Read(p) + // Trade off between memory and efficiency. 
+ if n == len(p) && crw.readBuffer.Len() == 0 && crw.readBuffer.Cap() > defaultReaderSize { + crw.readBuffer = bytes.Buffer{} + } + return +} + +// Read and uncompress the data into readBuffer. +// The format of the protocol: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_compression_packet.html +func (crw *compressedReadWriter) readFromConn() error { + var err error + var header [7]byte + if _, err = io.ReadFull(crw.packetReadWriter, header[:]); err != nil { + return err + } + compressedSequence := header[3] + if compressedSequence != crw.sequence { + return ErrInvalidSequence.GenWithStack( + "invalid compressed sequence, expected %d, actual %d", crw.sequence, compressedSequence) + } + crw.sequence++ + compressedLength := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) + uncompressedLength := int(uint32(header[4]) | uint32(header[5])<<8 | uint32(header[6])<<16) + + if uncompressedLength == 0 { + // If the data is uncompressed, the uncompressed length is 0 and compressed length is the data length + // after the compressed header. + crw.readBuffer.Grow(compressedLength) + if _, err = io.CopyN(&crw.readBuffer, crw.packetReadWriter, int64(compressedLength)); err != nil { + return errors.WithStack(err) + } + } else { + // If the data is compressed, the compressed length is the length of data after the compressed header and + // the uncompressed length is the length of data after decompression. 
+ data := make([]byte, compressedLength) + if _, err = io.ReadFull(crw.packetReadWriter, data); err != nil { + return err + } + if err = crw.uncompress(data, uncompressedLength); err != nil { + return err + } + } + return nil +} + +func (crw *compressedReadWriter) Write(data []byte) (n int, err error) { + crw.beginRW(rwWrite) + for { + remainingLen := maxCompressedSize - crw.writeBuffer.Len() + if len(data) <= remainingLen { + written, err := crw.writeBuffer.Write(data) + if err != nil { + return n, err + } + return n + written, nil + } + written, err := crw.writeBuffer.Write(data[:remainingLen]) + if err != nil { + return n, err + } + n += written + data = data[remainingLen:] + if err = crw.Flush(); err != nil { + return n, err + } + } +} + +func (crw *compressedReadWriter) Flush() error { + var err error + data := crw.writeBuffer.Bytes() + if len(data) == 0 { + return nil + } + // Trade off between memory and efficiency. + if crw.writeBuffer.Cap() > defaultWriterSize { + crw.writeBuffer = bytes.Buffer{} + } else { + crw.writeBuffer.Reset() + } + + // If the data is uncompressed, the uncompressed length is 0 and compressed length is the data length + // after the compressed header. + uncompressedLength := 0 + compressedLength := len(data) + if len(data) >= minCompressSize { + // If the data is compressed, the compressed length is the length of data after the compressed header and + // the uncompressed length is the length of data after decompression. 
+ uncompressedLength = len(data) + if data, err = crw.compress(data); err != nil { + return err + } + compressedLength = len(data) + } + + var compressedHeader [7]byte + compressedHeader[0] = byte(compressedLength) + compressedHeader[1] = byte(compressedLength >> 8) + compressedHeader[2] = byte(compressedLength >> 16) + compressedHeader[3] = crw.sequence + compressedHeader[4] = byte(uncompressedLength) + compressedHeader[5] = byte(uncompressedLength >> 8) + compressedHeader[6] = byte(uncompressedLength >> 16) + crw.sequence++ + if _, err = crw.packetReadWriter.Write(compressedHeader[:]); err != nil { + return errors.WithStack(err) + } + if _, err = crw.packetReadWriter.Write(data); err != nil { + return errors.WithStack(err) + } + return crw.packetReadWriter.Flush() +} + +// DirectWrite won't be used. +func (crw *compressedReadWriter) DirectWrite(data []byte) (n int, err error) { + if n, err = crw.Write(data); err != nil { + return + } + return n, crw.Flush() +} + +// Peek won't be used. +// Notice: the peeked data may be discarded if an error is returned. +func (crw *compressedReadWriter) Peek(n int) (data []byte, err error) { + crw.beginRW(rwRead) + for crw.readBuffer.Len() < n { + if err = crw.readFromConn(); err != nil { + return + } + } + data = make([]byte, 0, n) + copy(data, crw.readBuffer.Bytes()) + return +} + +// Discard won't be used. +func (crw *compressedReadWriter) Discard(n int) (d int, err error) { + crw.beginRW(rwRead) + for crw.readBuffer.Len() < n { + if err = crw.readFromConn(); err != nil { + return + } + } + crw.readBuffer.Next(n) + return n, err +} + +// DataDog/zstd is much faster but it's not good at cross-platform. 
+// https://github.com/go-mysql-org/go-mysql/issues/799 +func (crw *compressedReadWriter) compress(data []byte) ([]byte, error) { + var err error + var compressedPacket bytes.Buffer + var compressWriter io.WriteCloser + switch crw.algorithm { + case CompressionZlib: + compressWriter, err = zlib.NewWriterLevel(&compressedPacket, zlibCompressionLevel) + case CompressionZstd: + compressWriter, err = zstd.NewWriter(&compressedPacket, zstd.WithEncoderLevel(crw.zstdLevel)) + } + if err != nil { + return nil, errors.WithStack(err) + } + if _, err = compressWriter.Write(data); err != nil { + return nil, errors.WithStack(err) + } + if err = compressWriter.Close(); err != nil { + return nil, errors.WithStack(err) + } + return compressedPacket.Bytes(), nil +} + +func (crw *compressedReadWriter) uncompress(data []byte, uncompressedLength int) error { + var err error + var compressedReader io.ReadCloser + switch crw.algorithm { + case CompressionZlib: + if compressedReader, err = zlib.NewReader(bytes.NewReader(data)); err != nil { + return errors.WithStack(err) + } + case CompressionZstd: + var decoder *zstd.Decoder + if decoder, err = zstd.NewReader(bytes.NewReader(data)); err != nil { + return errors.WithStack(err) + } + compressedReader = decoder.IOReadCloser() + } + crw.readBuffer.Grow(uncompressedLength) + if _, err = io.CopyN(&crw.readBuffer, compressedReader, int64(uncompressedLength)); err != nil { + return errors.WithStack(err) + } + if err = compressedReader.Close(); err != nil { + return errors.WithStack(err) + } + return nil +} diff --git a/pkg/proxy/net/compress_test.go b/pkg/proxy/net/compress_test.go new file mode 100644 index 00000000..48a54141 --- /dev/null +++ b/pkg/proxy/net/compress_test.go @@ -0,0 +1,275 @@ +// Copyright 2023 PingCAP, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package net + +import ( + "fmt" + "io" + "math/rand" + "net" + "testing" + + "github.com/pingcap/tiproxy/lib/util/logger" + "github.com/pingcap/tiproxy/pkg/testkit" + "github.com/stretchr/testify/require" +) + +// Test read/write with zlib compression. +func TestCompressZlib(t *testing.T) { + sizes := []int{minCompressSize - 1, 1024, maxCompressedSize, maxCompressedSize + 1, maxCompressedSize * 2} + lg, _ := logger.CreateLoggerForTest(t) + testkit.TestTCPConn(t, + func(t *testing.T, c net.Conn) { + crw := newCompressedReadWriter(newBasicReadWriter(c), CompressionZlib, 0, lg) + written := crw.OutBytes() + for _, size := range sizes { + fillAndWrite(t, crw, 'a', size) + require.NoError(t, crw.Flush()) + // Check compressed bytes. + outBytes := crw.OutBytes() + checkWrittenByteSize(t, outBytes-written, size) + written = outBytes + } + }, + func(t *testing.T, c net.Conn) { + crw := newCompressedReadWriter(newBasicReadWriter(c), CompressionZlib, 0, lg) + for _, size := range sizes { + readAndCheck(t, crw, 'a', size) + } + }, 1) +} + +// Test read/write with zstd compression. +func TestCompressZstd(t *testing.T) { + sizes := []int{minCompressSize - 1, 1024, maxCompressedSize, maxCompressedSize + 1, maxCompressedSize * 2} + levels := []int{1, 3, 9, 22} + lg, _ := logger.CreateLoggerForTest(t) + for _, level := range levels { + testkit.TestTCPConn(t, + func(t *testing.T, c net.Conn) { + crw := newCompressedReadWriter(newBasicReadWriter(c), CompressionZstd, level, lg) + written := crw.OutBytes() + for _, size := range sizes { + fillAndWrite(t, crw, 'a', size) + require.NoError(t, crw.Flush()) + // Check compressed bytes. 
+ outBytes := crw.OutBytes() + checkWrittenByteSize(t, outBytes-written, size) + written = outBytes + } + }, + func(t *testing.T, c net.Conn) { + crw := newCompressedReadWriter(newBasicReadWriter(c), CompressionZstd, level, lg) + for _, size := range sizes { + readAndCheck(t, crw, 'a', size) + } + }, 1) + } +} + +// Test that multiple packets are merged into one compressed packet. +func TestCompressMergePkt(t *testing.T) { + lg, _ := logger.CreateLoggerForTest(t) + sizes := make([]int, 50) + for i := range sizes { + sizes[i] = int(rand.Int31n(maxCompressedSize / 2)) + } + testkit.TestTCPConn(t, + func(t *testing.T, c net.Conn) { + crw := newCompressedReadWriter(newBasicReadWriter(c), CompressionZlib, 0, lg) + written := 0 + for i, size := range sizes { + fillAndWrite(t, crw, 'a'+byte(i), size) + // Check that data is buffered until reaching maxCompressedSize. + written += size + require.Equal(t, written%maxCompressedSize, crw.writeBuffer.Len()) + } + require.NoError(t, crw.Flush()) + }, + func(t *testing.T, c net.Conn) { + crw := newCompressedReadWriter(newBasicReadWriter(c), CompressionZlib, 0, lg) + for i, size := range sizes { + readAndCheck(t, crw, 'a'+byte(i), size) + } + }, 1) +} + +// Test that DirectWrite(), Peek(), and Discard() work well. 
+func TestCompressPeekDiscard(t *testing.T) { + lg, _ := logger.CreateLoggerForTest(t) + size := 1000 + testkit.TestTCPConn(t, + func(t *testing.T, c net.Conn) { + crw := newCompressedReadWriter(newBasicReadWriter(c), CompressionZlib, 0, lg) + data := fillData('a', size) + _, err := crw.DirectWrite(data) + require.NoError(t, err) + + data = fillData('b', size) + _, err = crw.DirectWrite(data) + require.NoError(t, err) + }, + func(t *testing.T, c net.Conn) { + crw := newCompressedReadWriter(newBasicReadWriter(c), CompressionZlib, 0, lg) + peek, err := crw.Peek(10) + require.NoError(t, err) + checkData(t, peek, 'a') + readAndCheck(t, crw, 'a', size) + + _, err = crw.Discard(100) + require.NoError(t, err) + readAndCheck(t, crw, 'b', size-100) + }, 1) +} + +// Test that the uncompressed sequence is correct. +func TestCompressSequence(t *testing.T) { + lg, _ := logger.CreateLoggerForTest(t) + testkit.TestTCPConn(t, + func(t *testing.T, c net.Conn) { + crw := newCompressedReadWriter(newBasicReadWriter(c), CompressionZlib, 0, lg) + fillAndWrite(t, crw, 'a', 100) + fillAndWrite(t, crw, 'a', 100) + require.NoError(t, crw.Flush()) + require.Equal(t, uint8(2), crw.Sequence()) + // uncompressed sequence = compressed sequence + readAndCheck(t, crw, 'a', 100) + require.Equal(t, uint8(2), crw.Sequence()) + readAndCheck(t, crw, 'a', 100) + require.Equal(t, uint8(3), crw.Sequence()) + // uncompressed sequence = compressed sequence + fillAndWrite(t, crw, 'a', maxCompressedSize+1) + require.NoError(t, crw.Flush()) + require.Equal(t, uint8(3), crw.Sequence()) + // uncompressed sequence = compressed sequence + readAndCheck(t, crw, 'a', maxCompressedSize+1) + require.Equal(t, uint8(5), crw.Sequence()) + // flush empty buffer won't increase sequence + require.NoError(t, crw.Flush()) + require.NoError(t, crw.Flush()) + fillAndWrite(t, crw, 'a', 100) + require.NoError(t, crw.Flush()) + }, + func(t *testing.T, c net.Conn) { + crw := newCompressedReadWriter(newBasicReadWriter(c), 
CompressionZlib, 0, lg) + readAndCheck(t, crw, 'a', 100) + readAndCheck(t, crw, 'a', 100) + require.Equal(t, uint8(2), crw.Sequence()) + // uncompressed sequence = compressed sequence + fillAndWrite(t, crw, 'a', 100) + require.Equal(t, uint8(2), crw.Sequence()) + fillAndWrite(t, crw, 'a', 100) + require.Equal(t, uint8(3), crw.Sequence()) + require.NoError(t, crw.Flush()) + // uncompressed sequence = compressed sequence + readAndCheck(t, crw, 'a', maxCompressedSize+1) + require.Equal(t, uint8(3), crw.Sequence()) + // uncompressed sequence = compressed sequence + fillAndWrite(t, crw, 'a', maxCompressedSize+1) + require.NoError(t, crw.Flush()) + require.Equal(t, uint8(5), crw.Sequence()) + // flush empty buffer won't increase sequence + readAndCheck(t, crw, 'a', 100) + }, 1) +} + +// Test that the compressed header is correctly filled. +func TestCompressHeader(t *testing.T) { + lg, _ := logger.CreateLoggerForTest(t) + sizes := []int{minCompressSize - 1, maxCompressedSize, maxCompressedSize + 1} + testkit.TestTCPConn(t, + func(t *testing.T, c net.Conn) { + crw := newCompressedReadWriter(newBasicReadWriter(c), CompressionZlib, 0, lg) + for i, size := range sizes { + fillAndWrite(t, crw, 'a'+byte(i), size) + require.NoError(t, crw.Flush()) + } + }, + func(t *testing.T, c net.Conn) { + brw := newBasicReadWriter(c) + crw := newCompressedReadWriter(brw, CompressionZlib, 0, lg) + for i, size := range sizes { + header, err := brw.Peek(7) + require.NoError(t, err) + compressedLength := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) + uncompressedLength := int(uint32(header[4]) | uint32(header[5])<<8 | uint32(header[6])<<16) + if size < minCompressSize { + require.Equal(t, size, compressedLength) + require.Equal(t, 0, uncompressedLength) + } else if size <= maxCompressedSize { + require.Greater(t, compressedLength, 0) + require.Less(t, compressedLength, size) + require.Equal(t, size, uncompressedLength) + } else { + require.Greater(t, compressedLength, 0) 
+ require.Less(t, compressedLength, maxCompressedSize) + require.Equal(t, maxCompressedSize, uncompressedLength) + } + readAndCheck(t, crw, 'a'+byte(i), size) + } + }, 1) +} + +// Test that Read and Write returns correct errors. +func TestReadWriteError(t *testing.T) { + lg, _ := logger.CreateLoggerForTest(t) + testkit.TestTCPConn(t, + func(t *testing.T, c net.Conn) { + }, + func(t *testing.T, c net.Conn) { + crw := newCompressedReadWriter(newBasicReadWriter(c), CompressionZlib, 0, lg) + _, err := crw.Read(make([]byte, 1)) + require.True(t, IsDisconnectError(err)) + }, 1) + testkit.TestTCPConn(t, + func(t *testing.T, c net.Conn) { + }, + func(t *testing.T, c net.Conn) { + require.NoError(t, c.Close()) + crw := newCompressedReadWriter(newBasicReadWriter(c), CompressionZlib, 0, lg) + _, err := crw.Write(make([]byte, 1)) + require.NoError(t, err) + require.ErrorIs(t, crw.Flush(), net.ErrClosed) + }, 1) +} + +func fillAndWrite(t *testing.T, crw *compressedReadWriter, b byte, length int) { + data := fillData(b, length) + _, err := crw.Write(data) + require.NoError(t, err) + crw.SetSequence(crw.Sequence() + 1) +} + +func fillData(b byte, length int) []byte { + data := make([]byte, length) + for i := range data { + data[i] = b + } + return data +} + +func readAndCheck(t *testing.T, crw *compressedReadWriter, b byte, length int) { + data := make([]byte, length) + _, err := io.ReadFull(crw, data) + require.NoError(t, err) + checkData(t, data, b) + crw.SetSequence(crw.Sequence() + 1) +} + +func checkData(t *testing.T, data []byte, b byte) { + for i := range data { + if data[i] != b { + require.Fail(t, fmt.Sprintf("expected %c, but got %c", b, data[i])) + } + } +} + +func checkWrittenByteSize(t *testing.T, diff uint64, size int) { + if size < minCompressSize { + require.Equal(t, uint64(size+7), diff) + } else { + require.Greater(t, diff, uint64(0)) + require.Less(t, diff, uint64(size+7)) + } +} diff --git a/pkg/proxy/net/mysql.go b/pkg/proxy/net/mysql.go index 
c128bada..1cf0fc83 100644 --- a/pkg/proxy/net/mysql.go +++ b/pkg/proxy/net/mysql.go @@ -62,6 +62,7 @@ type HandshakeResp struct { AuthPlugin string AuthData []byte Capability Capability + ZstdLevel int Collation uint8 } @@ -132,11 +133,19 @@ func ParseHandshakeResponse(data []byte) (*HandshakeResp, error) { pos += off row := data[pos : pos+int(num)] resp.Attrs, err = parseAttrs(row) + // Some clients have known bugs, but we should be compatible with them. + // E.g. https://bugs.mysql.com/bug.php?id=79612. if err != nil { err = &errors.Warning{Err: errors.Wrapf(err, "parse attrs failed")} } + pos += int(num) } } + + // zstd compress level + if resp.Capability&ClientZstdCompressionAlgorithm > 0 { + resp.ZstdLevel = int(data[pos]) + } return resp, err } @@ -192,7 +201,7 @@ func MakeHandshakeResponse(resp *HandshakeResp) []byte { attrBuf = DumpLengthEncodedInt(attrLenBuf[:0], uint64(len(attrs))) } - length := 4 + 4 + 1 + 23 + len(resp.User) + 1 + len(authResp) + len(resp.AuthData) + len(resp.DB) + 1 + len(resp.AuthPlugin) + 1 + len(attrBuf) + len(attrs) + length := 4 + 4 + 1 + 23 + len(resp.User) + 1 + len(authResp) + len(resp.AuthData) + len(resp.DB) + 1 + len(resp.AuthPlugin) + 1 + len(attrBuf) + len(attrs) + 1 data := make([]byte, length) pos := 0 // capability [32 bit] @@ -244,6 +253,12 @@ func MakeHandshakeResponse(resp *HandshakeResp) []byte { pos += copy(data[pos:], attrBuf) pos += copy(data[pos:], attrs) } + + // compress level + if capability&ClientZstdCompressionAlgorithm > 0 { + data[pos] = byte(resp.ZstdLevel) + pos++ + } return data[:pos] } diff --git a/pkg/proxy/net/packetio.go b/pkg/proxy/net/packetio.go index 5b151698..2db79bf3 100644 --- a/pkg/proxy/net/packetio.go +++ b/pkg/proxy/net/packetio.go @@ -30,7 +30,6 @@ import ( "crypto/tls" "io" "net" - "sync/atomic" "time" "github.com/pingcap/tidb/errno" @@ -51,51 +50,120 @@ const ( defaultReaderSize = 16 * 1024 ) -// rdbufConn will buffer read for non-TLS connections. 
-// While TLS connections have internal buffering, we still need to pass *rdbufConn to `tls.XXX()`. -// Because TLS handshake data may already be buffered in `*rdbufConn`. -// TODO: only enable writer buffering for TLS connections, otherwise enable read/write buffering. -type rdbufConn struct { +// packetReadWriter acts like a net.Conn with read and write buffer. +type packetReadWriter interface { net.Conn - *bufio.Reader + // Peek / Discard / Flush are implemented by bufio.ReadWriter. + Peek(n int) ([]byte, error) + Discard(n int) (int, error) + Flush() error + DirectWrite(p []byte) (int, error) + Proxy() *proxyprotocol.Proxy + TLSConnectionState() tls.ConnectionState + InBytes() uint64 + OutBytes() uint64 + IsPeerActive() bool + SetSequence(uint8) + Sequence() uint8 } -func (f *rdbufConn) Read(b []byte) (int, error) { - return f.Reader.Read(b) +var _ packetReadWriter = (*basicReadWriter)(nil) + +// basicReadWriter is used for raw connections. +type basicReadWriter struct { + net.Conn + *bufio.ReadWriter + inBytes uint64 + outBytes uint64 + sequence uint8 +} + +func newBasicReadWriter(conn net.Conn) *basicReadWriter { + return &basicReadWriter{ + Conn: conn, + ReadWriter: bufio.NewReadWriter(bufio.NewReaderSize(conn, defaultReaderSize), bufio.NewWriterSize(conn, defaultWriterSize)), + } +} + +func (brw *basicReadWriter) Read(b []byte) (n int, err error) { + n, err = brw.ReadWriter.Read(b) + brw.inBytes += uint64(n) + return n, errors.WithStack(err) +} + +func (brw *basicReadWriter) Write(p []byte) (int, error) { + n, err := brw.ReadWriter.Write(p) + brw.outBytes += uint64(n) + return n, errors.WithStack(err) +} + +func (brw *basicReadWriter) DirectWrite(p []byte) (int, error) { + n, err := brw.Conn.Write(p) + brw.outBytes += uint64(n) + return n, errors.WithStack(err) +} + +func (brw *basicReadWriter) SetSequence(sequence uint8) { + brw.sequence = sequence +} + +func (brw *basicReadWriter) Sequence() uint8 { + return brw.sequence +} + +func (brw *basicReadWriter) 
Proxy() *proxyprotocol.Proxy { + return nil +} + +func (brw *basicReadWriter) InBytes() uint64 { + return brw.inBytes +} + +func (brw *basicReadWriter) OutBytes() uint64 { + return brw.outBytes +} + +func (brw *basicReadWriter) TLSConnectionState() tls.ConnectionState { + return tls.ConnectionState{} +} + +// IsPeerActive checks if the peer connection is still active. +// If the backend disconnects, the client should also be disconnected (required by serverless). +// We have no other way than reading from the connection. +// +// This function cannot be called concurrently with other functions of packetReadWriter. +// This function normally costs 1ms, so don't call it too frequently. +// This function may incorrectly return true if the system is extremely slow. +func (brw *basicReadWriter) IsPeerActive() bool { + if err := brw.Conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil { + return false + } + active := true + if _, err := brw.ReadWriter.Peek(1); err != nil { + active = !errors.Is(err, io.EOF) + } + if err := brw.Conn.SetReadDeadline(time.Time{}); err != nil { + return false + } + return active } // PacketIO is a helper to read and write sql and proxy protocol. type PacketIO struct { lastKeepAlive config.KeepAlive - inBytes uint64 - outBytes uint64 - conn net.Conn rawConn net.Conn - buf *bufio.ReadWriter - proxyInited atomic.Bool - proxy *proxyprotocol.Proxy + readWriter packetReadWriter logger *zap.Logger remoteAddr net.Addr wrap error - sequence uint8 } func NewPacketIO(conn net.Conn, lg *zap.Logger, opts ...PacketIOption) *PacketIO { - buf := bufio.NewReadWriter( - bufio.NewReaderSize(conn, defaultReaderSize), - bufio.NewWriterSize(conn, defaultWriterSize), - ) p := &PacketIO{ - rawConn: conn, - conn: &rdbufConn{ - conn, - buf.Reader, - }, - logger: lg, - sequence: 0, - buf: buf, + rawConn: conn, + logger: lg, + readWriter: newBasicReadWriter(conn), } - p.proxyInited.Store(true) p.ApplyOpts(opts...) 
return p } @@ -110,75 +178,42 @@ func (p *PacketIO) wrapErr(err error) error { return errors.WithStack(errors.Wrap(p.wrap, err)) } -// Proxy returned parsed proxy header from clients if any. -func (p *PacketIO) Proxy() *proxyprotocol.Proxy { - return p.proxy -} - func (p *PacketIO) LocalAddr() net.Addr { - return p.conn.LocalAddr() + return p.readWriter.LocalAddr() } func (p *PacketIO) RemoteAddr() net.Addr { if p.remoteAddr != nil { return p.remoteAddr } - return p.conn.RemoteAddr() + return p.readWriter.RemoteAddr() } func (p *PacketIO) ResetSequence() { - p.sequence = 0 + p.readWriter.SetSequence(0) } // GetSequence is used in tests to assert that the sequences on the client and server are equal. func (p *PacketIO) GetSequence() uint8 { - return p.sequence + return p.readWriter.Sequence() } func (p *PacketIO) readOnePacket() ([]byte, bool, error) { var header [4]byte - - if _, err := io.ReadFull(p.buf, header[:]); err != nil { + if _, err := io.ReadFull(p.readWriter, header[:]); err != nil { return nil, false, errors.Wrap(ErrReadConn, err) } - p.inBytes += 4 - - // probe proxy V2 - refill := false - if !p.proxyInited.Load() { - if bytes.Equal(header[:], proxyprotocol.MagicV2[:4]) { - proxyHeader, err := p.parseProxyV2() - if err != nil { - return nil, false, errors.Wrap(ErrReadConn, err) - } - if proxyHeader != nil { - p.proxy = proxyHeader - refill = true - } - } - p.proxyInited.Store(true) - } - - // refill mysql headers - if refill { - if _, err := io.ReadFull(p.buf, header[:]); err != nil { - return nil, false, errors.Wrap(ErrReadConn, err) - } - p.inBytes += 4 + sequence, pktSequence := header[3], p.readWriter.Sequence() + if sequence != pktSequence { + return nil, false, ErrInvalidSequence.GenWithStack("invalid sequence, expected %d, actual %d", pktSequence, sequence) } + p.readWriter.SetSequence(sequence + 1) - sequence := header[3] - if sequence != p.sequence { - return nil, false, ErrInvalidSequence.GenWithStack("invalid sequence %d != %d", sequence, 
p.sequence) - } - p.sequence++ length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) - data := make([]byte, length) - if _, err := io.ReadFull(p.buf, data); err != nil { + if _, err := io.ReadFull(p.readWriter, data); err != nil { return nil, false, errors.Wrap(ErrReadConn, err) } - p.inBytes += uint64(length) return data, length == MaxPayloadLen, nil } @@ -207,21 +242,20 @@ func (p *PacketIO) writeOnePacket(data []byte) (int, bool, error) { } var header [4]byte + sequence := p.readWriter.Sequence() header[0] = byte(length) header[1] = byte(length >> 8) header[2] = byte(length >> 16) - header[3] = p.sequence - p.sequence++ + header[3] = sequence + p.readWriter.SetSequence(sequence + 1) - if _, err := io.Copy(p.buf, bytes.NewReader(header[:])); err != nil { + if _, err := io.Copy(p.readWriter, bytes.NewReader(header[:])); err != nil { return 0, more, errors.Wrap(ErrWriteConn, err) } - p.outBytes += 4 - if _, err := io.Copy(p.buf, bytes.NewReader(data[:length])); err != nil { + if _, err := io.Copy(p.readWriter, bytes.NewReader(data[:length])); err != nil { return 0, more, errors.Wrap(ErrWriteConn, err) } - p.outBytes += uint64(length) return length, more, nil } @@ -244,43 +278,22 @@ func (p *PacketIO) WritePacket(data []byte, flush bool) (err error) { } func (p *PacketIO) InBytes() uint64 { - return p.inBytes + return p.readWriter.InBytes() } func (p *PacketIO) OutBytes() uint64 { - return p.outBytes -} - -func (p *PacketIO) TLSConnectionState() tls.ConnectionState { - if tlsConn, ok := p.conn.(*tls.Conn); ok { - return tlsConn.ConnectionState() - } - return tls.ConnectionState{} + return p.readWriter.OutBytes() } func (p *PacketIO) Flush() error { - if err := p.buf.Flush(); err != nil { + if err := p.readWriter.Flush(); err != nil { return p.wrapErr(errors.Wrap(ErrFlushConn, err)) } return nil } -// IsPeerActive checks if the peer connection is still active. -// This function cannot be called concurrently with other functions of PacketIO. 
-// This function normally costs 1ms, so don't call it too frequently. -// This function may incorrectly return true if the system is extremely slow. func (p *PacketIO) IsPeerActive() bool { - if err := p.conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil { - return false - } - active := true - if _, err := p.buf.Peek(1); err != nil { - active = !errors.Is(err, io.EOF) - } - if err := p.conn.SetReadDeadline(time.Time{}); err != nil { - return false - } - return active + return p.readWriter.IsPeerActive() } func (p *PacketIO) SetKeepalive(cfg config.KeepAlive) error { @@ -297,7 +310,7 @@ func (p *PacketIO) LastKeepAlive() config.KeepAlive { } func (p *PacketIO) GracefulClose() error { - if err := p.conn.SetDeadline(time.Now()); err != nil && !errors.Is(err, net.ErrClosed) { + if err := p.readWriter.SetDeadline(time.Now()); err != nil && !errors.Is(err, net.ErrClosed) { return err } return nil @@ -311,7 +324,7 @@ func (p *PacketIO) Close() error { errs = append(errs, err) } */ - if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + if err := p.readWriter.Close(); err != nil && !errors.Is(err, net.ErrClosed) { errs = append(errs, err) } return p.wrapErr(errors.Collect(ErrCloseConn, errs...)) diff --git a/pkg/proxy/net/packetio_options.go b/pkg/proxy/net/packetio_options.go index 83a0def0..0a815642 100644 --- a/pkg/proxy/net/packetio_options.go +++ b/pkg/proxy/net/packetio_options.go @@ -12,7 +12,7 @@ import ( type PacketIOption = func(*PacketIO) func WithProxy(pi *PacketIO) { - pi.proxyInited.Store(false) + pi.EnableProxyServer() } func WithWrapError(err error) func(pi *PacketIO) { diff --git a/pkg/proxy/net/packetio_test.go b/pkg/proxy/net/packetio_test.go index 14e0bdfc..e700a681 100644 --- a/pkg/proxy/net/packetio_test.go +++ b/pkg/proxy/net/packetio_test.go @@ -287,3 +287,78 @@ func TestPredefinedPacket(t *testing.T) { 1, ) } + +// Test the combination of proxy, tls and compress. 
+func TestProxyTLSCompress(t *testing.T) { + stls, ctls, err := security.CreateTLSConfigForTest() + require.NoError(t, err) + addr, p := mockProxy(t) + ch := make(chan []byte) + write := func(p *PacketIO, data []byte) { + outBytes := p.OutBytes() + require.NoError(t, p.WritePacket(data, true)) + ch <- data + require.Greater(t, p.OutBytes(), outBytes) + require.True(t, p.IsPeerActive()) + require.NotEmpty(t, p.RemoteAddr().String()) + } + read := func(p *PacketIO) { + inBytes := p.InBytes() + data := <-ch + pkt, err := p.ReadPacket() + require.NoError(t, err) + require.Equal(t, data, pkt) + require.Greater(t, p.InBytes(), inBytes) + require.True(t, p.IsPeerActive()) + require.NotEmpty(t, p.RemoteAddr().String()) + } + for _, enableCompress := range []bool{true, false} { + for _, enableTLS := range []bool{true, false} { + for _, enableProxy := range []bool{true, false} { + testTCPConn(t, func(t *testing.T, cli *PacketIO) { + if enableProxy { + cli.EnableProxyClient(p) + } + write(cli, []byte("test1")) + if enableTLS { + require.NoError(t, cli.ClientTLSHandshake(ctls)) + require.True(t, cli.TLSConnectionState().HandshakeComplete) + } + read(cli) + if enableCompress { + cli.ResetSequence() + require.NoError(t, cli.SetCompressionAlgorithm(CompressionZlib, 0)) + } + write(cli, []byte("test3")) + read(cli) + // make sure the peer won't quit in advance + ch <- nil + }, func(t *testing.T, srv *PacketIO) { + if enableProxy { + srv.EnableProxyServer() + } + read(srv) + if enableProxy { + require.Equal(t, addr.String(), srv.RemoteAddr().String()) + require.Equal(t, addr.String(), srv.Proxy().SrcAddress.String()) + } + if enableTLS { + state, err := srv.ServerTLSHandshake(stls) + require.NoError(t, err) + require.True(t, state.HandshakeComplete) + require.True(t, srv.TLSConnectionState().HandshakeComplete) + } + write(srv, []byte("test2")) + if enableCompress { + srv.ResetSequence() + require.NoError(t, srv.SetCompressionAlgorithm(CompressionZlib, 0)) + } + read(srv) + 
write(srv, []byte("test4")) + // make sure the peer won't quit in advance + <-ch + }, 1) + } + } + } +} diff --git a/pkg/proxy/net/proxy.go b/pkg/proxy/net/proxy.go index c86d5a22..f4da3146 100644 --- a/pkg/proxy/net/proxy.go +++ b/pkg/proxy/net/proxy.go @@ -6,46 +6,125 @@ package net import ( "bytes" "io" + "net" + "sync/atomic" "github.com/pingcap/tiproxy/lib/util/errors" "github.com/pingcap/tiproxy/pkg/proxy/proxyprotocol" ) -func (p *PacketIO) parseProxyV2() (*proxyprotocol.Proxy, error) { - rem, err := p.buf.Peek(8) +func (p *PacketIO) EnableProxyClient(proxy *proxyprotocol.Proxy) { + p.readWriter = newProxyClient(p.readWriter, proxy) +} + +func (p *PacketIO) EnableProxyServer() { + p.readWriter = newProxyServer(p.readWriter) +} + +// Proxy returned parsed proxy header from clients if any. +func (p *PacketIO) Proxy() *proxyprotocol.Proxy { + return p.readWriter.Proxy() +} + +var _ packetReadWriter = (*proxyReadWriter)(nil) + +type proxyReadWriter struct { + packetReadWriter + proxyInited atomic.Bool + proxy *proxyprotocol.Proxy + addr net.Addr + client bool +} + +func newProxyClient(rw packetReadWriter, proxy *proxyprotocol.Proxy) *proxyReadWriter { + prw := &proxyReadWriter{ + packetReadWriter: rw, + proxy: proxy, + client: true, + } + return prw +} + +func newProxyServer(rw packetReadWriter) *proxyReadWriter { + prw := &proxyReadWriter{ + packetReadWriter: rw, + client: false, + } + return prw +} + +func (prw *proxyReadWriter) Read(b []byte) (int, error) { + // probe proxy V2 + if !prw.client && !prw.proxyInited.Load() { + // We don't know whether the client has enabled proxy protocol. + // If it doesn't, reading data of len(MagicV2) may block forever. 
+ header, err := prw.Peek(4) + if err != nil { + return 0, errors.Wrap(ErrReadConn, err) + } + if bytes.Equal(header[:], proxyprotocol.MagicV2[:4]) { + proxyHeader, err := prw.parseProxyV2() + if err != nil { + return 0, errors.Wrap(ErrReadConn, err) + } + if proxyHeader != nil { + prw.proxy = proxyHeader + } + } + prw.proxyInited.Store(true) + } + return prw.packetReadWriter.Read(b) +} + +func (prw *proxyReadWriter) Write(p []byte) (n int, err error) { + // The proxy header should be written at the beginning of connection, before any write operations. + if prw.client && !prw.proxyInited.Load() { + buf, err := prw.proxy.ToBytes() + if err != nil { + return 0, errors.Wrap(ErrWriteConn, err) + } + if _, err := io.Copy(prw.packetReadWriter, bytes.NewReader(buf)); err != nil { + return 0, errors.Wrap(ErrWriteConn, err) + } + // according to the spec, we better flush to avoid server hanging + if err := prw.packetReadWriter.Flush(); err != nil { + return 0, err + } + prw.proxyInited.Store(true) + } + return prw.packetReadWriter.Write(p) +} + +func (prw *proxyReadWriter) parseProxyV2() (*proxyprotocol.Proxy, error) { + rem, err := prw.packetReadWriter.Peek(len(proxyprotocol.MagicV2)) if err != nil { return nil, errors.WithStack(errors.Wrap(ErrReadConn, err)) } - if !bytes.Equal(rem, proxyprotocol.MagicV2[4:]) { + if !bytes.Equal(rem, proxyprotocol.MagicV2) { return nil, nil } // yes, it is proxyV2 - _, err = p.buf.Discard(8) + _, err = prw.packetReadWriter.Discard(len(proxyprotocol.MagicV2)) if err != nil { return nil, errors.WithStack(errors.Wrap(ErrReadConn, err)) } - p.inBytes += 8 - m, n, err := proxyprotocol.ParseProxyV2(p.buf) - p.inBytes += uint64(n) + m, _, err := proxyprotocol.ParseProxyV2(prw.packetReadWriter) if err == nil { // set RemoteAddr in case of proxy. - p.remoteAddr = m.SrcAddress + prw.addr = m.SrcAddress } return m, err } -// WriteProxyV2 should only be called at the beginning of connection, before any write operations. 
-func (p *PacketIO) WriteProxyV2(m *proxyprotocol.Proxy) error { - buf, err := m.ToBytes() - if err != nil { - return errors.Wrap(ErrWriteConn, err) +func (prw *proxyReadWriter) RemoteAddr() net.Addr { + if prw.addr != nil { + return prw.addr } - if _, err := io.Copy(p.buf, bytes.NewReader(buf)); err != nil { - return errors.Wrap(ErrWriteConn, err) - } - p.outBytes += uint64(len(buf)) - // according to the spec, we better flush to avoid server hanging - return p.Flush() + return prw.packetReadWriter.RemoteAddr() +} + +func (prw *proxyReadWriter) Proxy() *proxyprotocol.Proxy { + return prw.proxy } diff --git a/pkg/proxy/net/proxy_test.go b/pkg/proxy/net/proxy_test.go index 15aaa545..da775025 100644 --- a/pkg/proxy/net/proxy_test.go +++ b/pkg/proxy/net/proxy_test.go @@ -10,34 +10,17 @@ import ( "testing" "github.com/pingcap/tiproxy/pkg/proxy/proxyprotocol" + "github.com/pingcap/tiproxy/pkg/testkit" "github.com/stretchr/testify/require" ) func TestProxyParse(t *testing.T) { - tcpaddr, err := net.ResolveTCPAddr("tcp", "192.168.1.1:34") - require.NoError(t, err) - + tcpaddr, p := mockProxy(t) testPipeConn(t, func(t *testing.T, cli *PacketIO) { - p := &proxyprotocol.Proxy{ - Version: proxyprotocol.ProxyVersion2, - Command: proxyprotocol.ProxyCommandLocal, - SrcAddress: tcpaddr, - DstAddress: tcpaddr, - TLV: []proxyprotocol.ProxyTlv{ - { - Typ: proxyprotocol.ProxyTlvALPN, - Content: nil, - }, - { - Typ: proxyprotocol.ProxyTlvUniqueID, - Content: []byte("test"), - }, - }, - } b, err := p.ToBytes() require.NoError(t, err) - _, err = io.Copy(cli.conn, bytes.NewReader(b)) + _, err = io.Copy(cli.readWriter, bytes.NewReader(b)) require.NoError(t, err) err = cli.WritePacket([]byte("hello"), true) require.NoError(t, err) @@ -52,3 +35,47 @@ func TestProxyParse(t *testing.T) { 1, ) } + +func TestProxyReadWrite(t *testing.T) { + addr, p := mockProxy(t) + message := []byte("hello world") + testkit.TestTCPConn(t, + func(t *testing.T, c net.Conn) { + prw := 
newProxyClient(newBasicReadWriter(c), p) + n, err := prw.Write(message) + require.NoError(t, err) + require.Equal(t, len(message), n) + require.NoError(t, prw.Flush()) + }, + func(t *testing.T, c net.Conn) { + prw := newProxyServer(newBasicReadWriter(c)) + data := make([]byte, len(message)) + n, err := prw.Read(data) + require.NoError(t, err) + require.Equal(t, len(message), n) + require.Equal(t, p.SrcAddress, prw.Proxy().SrcAddress) + require.Equal(t, addr.String(), prw.RemoteAddr().String()) + }, 1) +} + +func mockProxy(t *testing.T) (*net.TCPAddr, *proxyprotocol.Proxy) { + tcpaddr, err := net.ResolveTCPAddr("tcp", "192.168.1.1:34") + require.NoError(t, err) + p := &proxyprotocol.Proxy{ + Version: proxyprotocol.ProxyVersion2, + Command: proxyprotocol.ProxyCommandLocal, + SrcAddress: tcpaddr, + DstAddress: tcpaddr, + TLV: []proxyprotocol.ProxyTlv{ + { + Typ: proxyprotocol.ProxyTlvALPN, + Content: nil, + }, + { + Typ: proxyprotocol.ProxyTlvUniqueID, + Content: []byte("test"), + }, + }, + } + return tcpaddr, p +} diff --git a/pkg/proxy/net/tls.go b/pkg/proxy/net/tls.go index 226280f5..2d5bdf91 100644 --- a/pkg/proxy/net/tls.go +++ b/pkg/proxy/net/tls.go @@ -10,28 +10,88 @@ import ( "github.com/pingcap/tiproxy/lib/util/errors" ) +// tlsInternalConn is only used as the underlying connection in tls.Conn. +// TLS handshake must read from the buffered reader because the handshake data may be already buffered in the reader. +// TLS handshake can not use the buffered writer directly because it assumes the data will be flushed automatically, +// however buffered writer may not flush without calling `Flush`. 
+type tlsInternalConn struct { + packetReadWriter +} + +func (br *tlsInternalConn) Write(p []byte) (n int, err error) { + return br.packetReadWriter.DirectWrite(p) +} + func (p *PacketIO) ServerTLSHandshake(tlsConfig *tls.Config) (tls.ConnectionState, error) { tlsConfig = tlsConfig.Clone() - tlsConn := tls.Server(p.conn, tlsConfig) + conn := &tlsInternalConn{p.readWriter} + tlsConn := tls.Server(conn, tlsConfig) if err := tlsConn.Handshake(); err != nil { return tls.ConnectionState{}, p.wrapErr(errors.Wrap(ErrHandshakeTLS, err)) } - p.conn = tlsConn - p.buf.Writer.Reset(tlsConn) - // Wrap it with another buffer to enable Peek. - p.buf = bufio.NewReadWriter(bufio.NewReaderSize(tlsConn, defaultReaderSize), p.buf.Writer) + p.readWriter = newTLSReadWriter(p.readWriter, tlsConn) return tlsConn.ConnectionState(), nil } func (p *PacketIO) ClientTLSHandshake(tlsConfig *tls.Config) error { tlsConfig = tlsConfig.Clone() - tlsConn := tls.Client(p.conn, tlsConfig) + conn := &tlsInternalConn{p.readWriter} + tlsConn := tls.Client(conn, tlsConfig) if err := tlsConn.Handshake(); err != nil { return errors.WithStack(errors.Wrap(ErrHandshakeTLS, err)) } - p.conn = tlsConn - p.buf.Writer.Reset(tlsConn) - // Wrap it with another buffer to enable Peek. - p.buf = bufio.NewReadWriter(bufio.NewReaderSize(tlsConn, defaultReaderSize), p.buf.Writer) + p.readWriter = newTLSReadWriter(p.readWriter, tlsConn) return nil } + +func (p *PacketIO) TLSConnectionState() tls.ConnectionState { + return p.readWriter.TLSConnectionState() +} + +var _ packetReadWriter = (*tlsReadWriter)(nil) + +type tlsReadWriter struct { + packetReadWriter + buf *bufio.ReadWriter + conn *tls.Conn +} + +func newTLSReadWriter(rw packetReadWriter, tlsConn *tls.Conn) *tlsReadWriter { + // Can not modify rw and reuse it because tlsConn is using rw internally. + // We must create another buffer. 
+ buf := bufio.NewReadWriter(bufio.NewReaderSize(tlsConn, defaultReaderSize), bufio.NewWriterSize(tlsConn, defaultWriterSize)) + return &tlsReadWriter{ + packetReadWriter: rw, + buf: buf, + conn: tlsConn, + } +} + +func (trw *tlsReadWriter) Read(b []byte) (n int, err error) { + // inBytes and outBytes are updated internally in trw.packetReadWriter. + return trw.buf.Read(b) +} + +func (trw *tlsReadWriter) Write(p []byte) (int, error) { + return trw.buf.Write(p) +} + +func (trw *tlsReadWriter) DirectWrite(p []byte) (int, error) { + return trw.conn.Write(p) +} + +func (trw *tlsReadWriter) Peek(n int) ([]byte, error) { + return trw.buf.Peek(n) +} + +func (trw *tlsReadWriter) Discard(n int) (int, error) { + return trw.buf.Discard(n) +} + +func (trw *tlsReadWriter) Flush() error { + return trw.buf.Flush() +} + +func (trw *tlsReadWriter) TLSConnectionState() tls.ConnectionState { + return trw.conn.ConnectionState() +} diff --git a/pkg/proxy/net/tls_test.go b/pkg/proxy/net/tls_test.go new file mode 100644 index 00000000..91f4d192 --- /dev/null +++ b/pkg/proxy/net/tls_test.go @@ -0,0 +1,84 @@ +// Copyright 2023 PingCAP, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package net + +import ( + "crypto/tls" + "io" + "net" + "testing" + + "github.com/pingcap/tiproxy/lib/util/security" + "github.com/pingcap/tiproxy/pkg/testkit" + "github.com/stretchr/testify/require" +) + +func TestTLSReadWrite(t *testing.T) { + stls, ctls, err := security.CreateTLSConfigForTest() + require.NoError(t, err) + message := []byte("hello world") + ch := make(chan []byte) + testkit.TestTCPConn(t, + func(t *testing.T, c net.Conn) { + brw := newBasicReadWriter(c) + conn := &tlsInternalConn{brw} + tlsConn := tls.Client(conn, ctls) + require.NoError(t, tlsConn.Handshake()) + trw := newTLSReadWriter(brw, tlsConn) + // check tls connection state + require.True(t, trw.TLSConnectionState().HandshakeComplete) + // check out bytes + outBytes := trw.OutBytes() + // Wait before writing, otherwise the message is buffered during TLS in the other goroutine. + ch <- message + n, err := trw.Write(message) + require.NoError(t, err) + require.NoError(t, trw.Flush()) + require.Equal(t, len(message), n) + require.Greater(t, trw.OutBytes(), outBytes+uint64(len(message))) + // check direct write + for i := 0; i < 2; i++ { + n, err = trw.DirectWrite(message) + require.NoError(t, err) + require.Equal(t, len(message), n) + } + }, + func(t *testing.T, c net.Conn) { + brw := newBasicReadWriter(c) + conn := &tlsInternalConn{brw} + tlsConn := tls.Server(conn, stls) + require.NoError(t, tlsConn.Handshake()) + trw := newTLSReadWriter(brw, tlsConn) + // check tls connection state + require.True(t, trw.TLSConnectionState().HandshakeComplete) + // check in bytes + inBytes := trw.InBytes() + message := <-ch + data := make([]byte, len(message)) + n, err := io.ReadFull(trw, data) + require.NoError(t, err) + require.Equal(t, len(message), n) + require.Equal(t, message, data) + require.Greater(t, trw.InBytes(), inBytes+uint64(len(message))) + // check peek + peek, err := trw.Peek(1) + require.NoError(t, err) + require.Len(t, peek, 1) + 
require.Equal(t, message[0], peek[0]) + data = make([]byte, len(message)) + n, err = io.ReadFull(trw, data) + require.NoError(t, err) + require.Equal(t, len(data), n) + require.Equal(t, message, data) + // check discard + n, err = trw.Discard(1) + require.NoError(t, err) + require.Equal(t, 1, n) + data = make([]byte, len(message)-1) + n, err = io.ReadFull(trw, data) + require.NoError(t, err) + require.Equal(t, len(data), n) + require.Equal(t, message[1:], data) + }, 1) +}