From b7afff823b21b079f131449732b0f2d5c348e79c Mon Sep 17 00:00:00 2001 From: djshow832 <873581766@qq.com> Date: Wed, 30 Nov 2022 10:54:57 +0800 Subject: [PATCH 1/2] add HandshakeHandler --- pkg/proxy/backend/authenticator.go | 22 +++++---- pkg/proxy/backend/authenticator_test.go | 27 ++++++++++- pkg/proxy/backend/backend_conn_mgr.go | 44 +++++++++-------- pkg/proxy/backend/backend_conn_mgr_test.go | 2 +- pkg/proxy/backend/handshake_handler.go | 55 +++++++++++++++++++++ pkg/proxy/backend/mock_backend_test.go | 2 +- pkg/proxy/backend/mock_client_test.go | 4 +- pkg/proxy/backend/mock_proxy_test.go | 30 +++++++++++- pkg/proxy/client/client_conn.go | 5 +- pkg/proxy/net/mysql.go | 56 ++++++++++++++++++---- pkg/proxy/net/mysql_test.go | 37 ++++++++++++++ pkg/proxy/net/packetio.go | 8 ++++ pkg/proxy/net/protocol.go | 22 ++++++++- 13 files changed, 268 insertions(+), 46 deletions(-) create mode 100644 pkg/proxy/backend/handshake_handler.go create mode 100644 pkg/proxy/net/mysql_test.go diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index a64be0c7..ef90d6f2 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -35,8 +35,9 @@ const unknownAuthPlugin = "auth_unknown_plugin" const requiredFrontendCaps = pnet.ClientProtocol41 const defRequiredBackendCaps = pnet.ClientDeprecateEOF -// Other server capabilities are not supported. ClientDeprecateEOF is supported but TiDB 6.2.0 doesn't support it now. -const supportedServerCapabilities = pnet.ClientLongPassword | pnet.ClientFoundRows | pnet.ClientConnectWithDB | +// SupportedServerCapabilities is the default supported capabilities. Other server capabilities are not supported. +// TiDB supports ClientDeprecateEOF since v6.3.0. +const SupportedServerCapabilities = pnet.ClientLongPassword | pnet.ClientFoundRows | pnet.ClientConnectWithDB | pnet.ClientODBC | pnet.ClientLocalFiles | pnet.ClientInteractive | pnet.ClientLongFlag | pnet.ClientSSL | pnet.ClientTransactions | pnet.ClientReserved | pnet.ClientSecureConnection | pnet.ClientMultiStatements | pnet.ClientMultiResults | pnet.ClientPluginAuth | pnet.ClientConnectAttrs | pnet.ClientPluginAuthLenencClientData | @@ -49,7 +50,7 @@ type Authenticator struct { dbname string // default database name serverAddr string user string - attrs []byte // no need to parse + attrs map[string]string salt []byte capability uint32 // client capability collation uint8 @@ -72,7 +73,7 @@ func (auth *Authenticator) writeProxyProtocol(clientIO, backendIO *pnet.PacketIO Version: pnet.ProxyVersion2, } } - // either from another proxy or directly from clients, we are actings as a proxy + // either from another proxy or directly from clients, we are acting as a proxy proxy.Command = pnet.ProxyCommandProxy if err := backendIO.WriteProxyV2(proxy); err != nil { return err @@ -82,7 +83,7 @@ func (auth *Authenticator) writeProxyProtocol(clientIO, backendIO *pnet.PacketIO } func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapability pnet.Capability) error { - requiredBackendCaps := defRequiredBackendCaps + requiredBackendCaps := defRequiredBackendCaps & pnet.Capability(auth.capability) if auth.requireBackendTLS { requiredBackendCaps |= pnet.ClientSSL } @@ -97,7 +98,8 @@ func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapabili return nil } -func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet.PacketIO, getBackendIO func(*Authenticator) (*pnet.PacketIO, error), frontendTLSConfig, backendTLSConfig *tls.Config) error { +func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet.PacketIO, handshakeHandler HandshakeHandler, + getBackend backendIOGetter, frontendTLSConfig, backendTLSConfig *tls.Config) error { clientIO.ResetSequence() proxyCapability := auth.supportedServerCapabilities @@ -140,14 +142,18 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet if frontendCapability^commonCaps != 0 { logger.Debug("frontend send capabilities unsupported by proxy", zap.Stringer("common", commonCaps), zap.Stringer("frontend", frontendCapability^commonCaps), zap.Stringer("proxy", proxyCapability^commonCaps)) } - resp := pnet.ParseHandshakeResponse(pkt) auth.capability = commonCaps.Uint32() + + resp := pnet.ParseHandshakeResponse(pkt) + if err = handshakeHandler.HandleHandshakeResp(resp, clientIO.SourceAddr().String()); err != nil { + return err + } auth.user = resp.User auth.dbname = resp.DB auth.collation = resp.Collation auth.attrs = resp.Attrs - backendIO, err := getBackendIO(auth) + backendIO, err := getBackend(auth, resp) if err != nil { return err } diff --git a/pkg/proxy/backend/authenticator_test.go b/pkg/proxy/backend/authenticator_test.go index 33164383..f00a9d77 100644 --- a/pkg/proxy/backend/authenticator_test.go +++ b/pkg/proxy/backend/authenticator_test.go @@ -15,6 +15,7 @@ package backend import ( + "net" "strings" "testing" @@ -162,7 +163,7 @@ func TestCapability(t *testing.T) { }, func(cfg *testConfig) { cfg.clientConfig.capability = defaultTestClientCapability | mysql.ClientConnectAtts - cfg.clientConfig.attrs = []byte(strings.Repeat("x", 512)) + cfg.clientConfig.attrs = map[string]string{"key": "value"} }, }, { @@ -207,3 +208,27 @@ func TestSecondHandshake(t *testing.T) { clean() } } + +func TestCustomAuth(t *testing.T) { + tc := newTCPConnSuite(t) + handler := &CustomHandshakeHandler{ + outUsername: "rewritten_user", + outAttrs: map[string]string{"key": "value"}, + } + ts, clean := newTestSuite(t, tc, func(cfg *testConfig) { + cfg.proxyConfig.handler = handler + }) + checker := func() { + require.Equal(t, ts.mc.username, handler.inUsername) + require.Equal(t, handler.outUsername, ts.mb.username) + require.Equal(t, handler.outAttrs, ts.mb.attrs) + host, _, err := net.SplitHostPort(handler.inAddr) + require.NoError(t, err) + require.Equal(t, host, "::1") + } + ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) {}) + checker() + ts.authenticateSecondTime(t, func(t *testing.T, ts *testSuite) {}) + checker() + clean() +} diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 61bda683..15844e86 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -55,6 +55,8 @@ type redirectResult struct { to string } +type backendIOGetter func(auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) + // BackendConnManager migrates a session from one BackendConnection to another. // // The signal processing goroutine tries to migrate the session once it receives a signal. @@ -80,22 +82,25 @@ type BackendConnManager struct { // redirectResCh is used to notify the event receiver asynchronously. redirectResCh chan *redirectResult // cancelFunc is used to cancel the signal processing goroutine. - cancelFunc context.CancelFunc - backendConn *BackendConnection - nsmgr *namespace.NamespaceManager - getBackendIO func(*Authenticator) (*pnet.PacketIO, error) - connectionID uint64 + cancelFunc context.CancelFunc + backendConn *BackendConnection + nsmgr *namespace.NamespaceManager + handshakeHandler HandshakeHandler + getBackendIO backendIOGetter + connectionID uint64 } // NewBackendConnManager creates a BackendConnManager. -func NewBackendConnManager(logger *zap.Logger, nsmgr *namespace.NamespaceManager, connectionID uint64, proxyProtocol, requireBackendTLS bool) *BackendConnManager { +func NewBackendConnManager(logger *zap.Logger, nsmgr *namespace.NamespaceManager, handshakeHandler HandshakeHandler, + connectionID uint64, proxyProtocol, requireBackendTLS bool) *BackendConnManager { mgr := &BackendConnManager{ - logger: logger, - connectionID: connectionID, - cmdProcessor: NewCmdProcessor(), - nsmgr: nsmgr, + logger: logger, + connectionID: connectionID, + cmdProcessor: NewCmdProcessor(), + nsmgr: nsmgr, + handshakeHandler: handshakeHandler, authenticator: &Authenticator{ - supportedServerCapabilities: supportedServerCapabilities, + supportedServerCapabilities: handshakeHandler.GetCapability(), proxyProtocol: proxyProtocol, requireBackendTLS: requireBackendTLS, salt: GenerateSalt(20), @@ -103,13 +108,10 @@ func NewBackendConnManager(logger *zap.Logger, nsmgr *namespace.NamespaceManager signalReceived: make(chan struct{}, 1), redirectResCh: make(chan *redirectResult, 1), } - mgr.getBackendIO = func(auth *Authenticator) (*pnet.PacketIO, error) { - ns, ok := mgr.nsmgr.GetNamespaceByUser(auth.user) - if !ok { - ns, ok = mgr.nsmgr.GetNamespace("default") - } - if !ok { - return nil, errors.New("failed to find a namespace") + mgr.getBackendIO = func(auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) { + ns, err := handshakeHandler.GetNamespace(nsmgr, resp) + if err != nil { + return nil, err } router := ns.GetRouter() addr, err := router.Route(mgr) @@ -135,7 +137,8 @@ func (mgr *BackendConnManager) ConnectionID() uint64 { } // Connect connects to the first backend and then start watching redirection signals. -func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.PacketIO, getBackendIO func(auth *Authenticator) (*pnet.PacketIO, error), frontendTLSConfig, backendTLSConfig *tls.Config) error { +func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.PacketIO, getBackendIO backendIOGetter, + frontendTLSConfig, backendTLSConfig *tls.Config) error { mgr.processLock.Lock() defer mgr.processLock.Unlock() @@ -143,7 +146,8 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe getBackendIO = mgr.getBackendIO } - if err := mgr.authenticator.handshakeFirstTime(mgr.logger.Named("authenticator"), clientIO, getBackendIO, frontendTLSConfig, backendTLSConfig); err != nil { + if err := mgr.authenticator.handshakeFirstTime(mgr.logger.Named("authenticator"), clientIO, mgr.handshakeHandler, + getBackendIO, frontendTLSConfig, backendTLSConfig); err != nil { return err } diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index 99b941f5..fbbb585b 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -116,7 +116,7 @@ func newBackendMgrTester(t *testing.T) *backendMgrTester { return tester } -func (ts *backendMgrTester) getBackendIO(auth *Authenticator) (*pnet.PacketIO, error) { +func (ts *backendMgrTester) getBackendIO(auth *Authenticator, _ *pnet.HandshakeResp) (*pnet.PacketIO, error) { addr := ts.tc.backendListener.Addr().String() ts.mp.backendConn = NewBackendConnection(addr) if err := ts.mp.backendConn.Connect(); err != nil { diff --git a/pkg/proxy/backend/handshake_handler.go b/pkg/proxy/backend/handshake_handler.go new file mode 100644 index 00000000..978e64ad --- /dev/null +++ b/pkg/proxy/backend/handshake_handler.go @@ -0,0 +1,55 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package backend + +import ( + "github.com/pingcap/TiProxy/lib/util/errors" + "github.com/pingcap/TiProxy/pkg/manager/namespace" + pnet "github.com/pingcap/TiProxy/pkg/proxy/net" +) + +var _ HandshakeHandler = (*DefaultHandshakeHandler)(nil) + +type HandshakeHandler interface { + HandleHandshakeResp(resp *pnet.HandshakeResp, sourceAddr string) error + GetCapability() pnet.Capability + GetNamespace(nsMgr *namespace.NamespaceManager, resp *pnet.HandshakeResp) (*namespace.Namespace, error) +} + +type DefaultHandshakeHandler struct { +} + +func NewDefaultHandshakeHandler() *DefaultHandshakeHandler { + return &DefaultHandshakeHandler{} +} + +func (handler *DefaultHandshakeHandler) HandleHandshakeResp(*pnet.HandshakeResp, string) error { + return nil +} + +func (handler *DefaultHandshakeHandler) GetCapability() pnet.Capability { + return SupportedServerCapabilities +} + +func (handler *DefaultHandshakeHandler) GetNamespace(nsMgr *namespace.NamespaceManager, resp *pnet.HandshakeResp) (*namespace.Namespace, error) { + ns, ok := nsMgr.GetNamespaceByUser(resp.User) + if !ok { + ns, ok = nsMgr.GetNamespace("default") + } + if !ok { + return nil, errors.New("failed to find a namespace") + } + return ns, nil +} diff --git a/pkg/proxy/backend/mock_backend_test.go b/pkg/proxy/backend/mock_backend_test.go index d1154365..dabde438 100644 --- a/pkg/proxy/backend/mock_backend_test.go +++ b/pkg/proxy/backend/mock_backend_test.go @@ -59,8 +59,8 @@ type mockBackend struct { // Outputs that received from the client and will be checked by the test. username string db string + attrs map[string]string authData []byte - attrs []byte clientCapability uint32 } diff --git a/pkg/proxy/backend/mock_client_test.go b/pkg/proxy/backend/mock_client_test.go index 32733a83..de2504ea 100644 --- a/pkg/proxy/backend/mock_client_test.go +++ b/pkg/proxy/backend/mock_client_test.go @@ -29,8 +29,8 @@ type clientConfig struct { username string dbName string authPlugin string + attrs map[string]string dataBytes []byte - attrs []byte authData []byte filePkts int prepStmtID int @@ -49,7 +49,7 @@ func newClientConfig() *clientConfig { dbName: mockDBName, authPlugin: mysql.AuthCachingSha2Password, authData: mockAuthData, - attrs: make([]byte, 0), + attrs: nil, cmd: mysql.ComQuery, dataBytes: mockCmdBytes, sql: mockCmdStr, diff --git a/pkg/proxy/backend/mock_proxy_test.go b/pkg/proxy/backend/mock_proxy_test.go index ff77d155..3c346ece 100644 --- a/pkg/proxy/backend/mock_proxy_test.go +++ b/pkg/proxy/backend/mock_proxy_test.go @@ -20,6 +20,7 @@ import ( gomysql "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/TiProxy/lib/util/logger" + "github.com/pingcap/TiProxy/pkg/manager/namespace" pnet "github.com/pingcap/TiProxy/pkg/proxy/net" "go.uber.org/zap" ) @@ -27,6 +28,7 @@ import ( type proxyConfig struct { frontendTLSConfig *tls.Config backendTLSConfig *tls.Config + handler HandshakeHandler sessionToken string capability uint32 waitRedirect bool @@ -34,6 +36,7 @@ type proxyConfig struct { func newProxyConfig() *proxyConfig { return &proxyConfig{ + handler: NewDefaultHandshakeHandler(), capability: defaultTestBackendCapability, sessionToken: mockToken, } @@ -54,14 +57,14 @@ func newMockProxy(t *testing.T, cfg *proxyConfig) *mockProxy { mp := &mockProxy{ proxyConfig: cfg, logger: logger.CreateLoggerForTest(t).Named("mockProxy"), - BackendConnManager: NewBackendConnManager(logger.CreateLoggerForTest(t), nil, 0, false, false), + BackendConnManager: NewBackendConnManager(logger.CreateLoggerForTest(t), nil, cfg.handler, 0, false, false), } mp.cmdProcessor.capability = cfg.capability return mp } func (mp *mockProxy) authenticateFirstTime(clientIO, backendIO *pnet.PacketIO) error { - return mp.authenticator.handshakeFirstTime(mp.logger, clientIO, func(_ *Authenticator) (*pnet.PacketIO, error) { + return mp.authenticator.handshakeFirstTime(mp.logger, clientIO, mp.handshakeHandler, func(*Authenticator, *pnet.HandshakeResp) (*pnet.PacketIO, error) { return backendIO, nil }, mp.frontendTLSConfig, mp.backendTLSConfig) } @@ -91,3 +94,26 @@ func (mp *mockProxy) directQuery(_, backendIO *pnet.PacketIO) error { mp.rs = rs return err } + +type CustomHandshakeHandler struct { + inUsername string + inAddr string + outUsername string + outAttrs map[string]string +} + +func (handler *CustomHandshakeHandler) GetNamespace(nsMgr *namespace.NamespaceManager, resp *pnet.HandshakeResp) (*namespace.Namespace, error) { + return &namespace.Namespace{}, nil +} + +func (handler *CustomHandshakeHandler) HandleHandshakeResp(resp *pnet.HandshakeResp, addr string) error { + handler.inUsername = resp.User + resp.User = handler.outUsername + handler.inAddr = addr + resp.Attrs = handler.outAttrs + return nil +} + +func (handler *CustomHandshakeHandler) GetCapability() pnet.Capability { + return SupportedServerCapabilities & ^pnet.ClientDeprecateEOF +} diff --git a/pkg/proxy/client/client_conn.go b/pkg/proxy/client/client_conn.go index 4f100f91..d50fe40c 100644 --- a/pkg/proxy/client/client_conn.go +++ b/pkg/proxy/client/client_conn.go @@ -40,8 +40,9 @@ type ClientConnection struct { connMgr *backend.BackendConnManager } -func NewClientConnection(logger *zap.Logger, conn net.Conn, frontendTLSConfig *tls.Config, backendTLSConfig *tls.Config, nsmgr *namespace.NamespaceManager, connID uint64, proxyProtocol, requireBackendTLS bool) *ClientConnection { - bemgr := backend.NewBackendConnManager(logger.Named("be"), nsmgr, connID, proxyProtocol, requireBackendTLS) +func NewClientConnection(logger *zap.Logger, conn net.Conn, frontendTLSConfig *tls.Config, backendTLSConfig *tls.Config, + nsmgr *namespace.NamespaceManager, connID uint64, proxyProtocol, requireBackendTLS bool) *ClientConnection { + bemgr := backend.NewBackendConnManager(logger.Named("be"), nsmgr, backend.NewDefaultHandshakeHandler(), connID, proxyProtocol, requireBackendTLS) opts := make([]pnet.PacketIOption, 0, 2) opts = append(opts, pnet.WithWrapError(ErrClientConn)) if proxyProtocol { diff --git a/pkg/proxy/net/mysql.go b/pkg/proxy/net/mysql.go index f9a4944f..ac58cd66 100644 --- a/pkg/proxy/net/mysql.go +++ b/pkg/proxy/net/mysql.go @@ -64,10 +64,10 @@ func ParseInitialHandshake(data []byte) uint32 { // HandshakeResp indicates the response read from the client. type HandshakeResp struct { + Attrs map[string]string User string DB string AuthPlugin string - Attrs []byte AuthData []byte Capability uint32 Collation uint8 @@ -137,17 +137,56 @@ func ParseHandshakeResponse(data []byte) *HandshakeResp { if resp.Capability&mysql.ClientConnectAtts > 0 { if num, null, off := ParseLengthEncodedInt(data[pos:]); !null { pos += off - resp.Attrs = data[pos : pos+int(num)] + row := data[pos : pos+int(num)] + attrs, err := parseAttrs(row) + if err != nil { + return nil + } + resp.Attrs = attrs } } return resp } +func parseAttrs(data []byte) (map[string]string, error) { + attrs := make(map[string]string) + pos := 0 + for pos < len(data) { + key, _, off, err := ParseLengthEncodedBytes(data[pos:]) + if err != nil { + return attrs, err + } + pos += off + value, _, off, err := ParseLengthEncodedBytes(data[pos:]) + if err != nil { + return attrs, err + } + pos += off + + attrs[string(key)] = string(value) + } + return attrs, nil +} + +func dumpAttrs(attrs map[string]string) []byte { + var buf bytes.Buffer + var keyBuf []byte + for k, v := range attrs { + keyBuf = keyBuf[0:0] + keyBuf = DumpLengthEncodedString(keyBuf, []byte(k)) + buf.Write(keyBuf) + keyBuf = keyBuf[0:0] + keyBuf = DumpLengthEncodedString(keyBuf, []byte(v)) + buf.Write(keyBuf) + } + return buf.Bytes() +} + func MakeHandshakeResponse(resp *HandshakeResp) []byte { // encode length of the auth data var ( - authRespBuf, attrRespBuf [9]byte - authResp, attrResp []byte + authRespBuf, attrLenBuf [9]byte + authResp, attrs, attrBuf []byte ) authResp = DumpLengthEncodedInt(authRespBuf[:0], uint64(len(resp.AuthData))) capability := resp.Capability @@ -157,10 +196,11 @@ func MakeHandshakeResponse(resp *HandshakeResp) []byte { capability &= ^mysql.ClientPluginAuthLenencClientData } if capability&mysql.ClientConnectAtts > 0 { - attrResp = DumpLengthEncodedInt(attrRespBuf[:0], uint64(len(resp.Attrs))) + attrs = dumpAttrs(resp.Attrs) + attrBuf = DumpLengthEncodedInt(attrLenBuf[:0], uint64(len(attrs))) } - length := 4 + 4 + 1 + 23 + len(resp.User) + 1 + len(authResp) + len(resp.AuthData) + len(resp.DB) + 1 + len(resp.AuthPlugin) + 1 + len(attrResp) + len(resp.Attrs) + length := 4 + 4 + 1 + 23 + len(resp.User) + 1 + len(authResp) + len(resp.AuthData) + len(resp.DB) + 1 + len(resp.AuthPlugin) + 1 + len(attrBuf) + len(attrs) data := make([]byte, length) pos := 0 // capability [32 bit] @@ -209,8 +249,8 @@ func MakeHandshakeResponse(resp *HandshakeResp) []byte { // attrs if capability&mysql.ClientConnectAtts > 0 { - pos += copy(data[pos:], attrResp) - pos += copy(data[pos:], resp.Attrs) + pos += copy(data[pos:], attrBuf) + pos += copy(data[pos:], attrs) } return data[:pos] } diff --git a/pkg/proxy/net/mysql_test.go b/pkg/proxy/net/mysql_test.go new file mode 100644 index 00000000..0efe2748 --- /dev/null +++ b/pkg/proxy/net/mysql_test.go @@ -0,0 +1,37 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package net + +import ( + "testing" + + "github.com/pingcap/tidb/parser/mysql" + "github.com/stretchr/testify/require" +) + +func TestHandshakeResp(t *testing.T) { + resp1 := &HandshakeResp{ + Attrs: map[string]string{"key": "value"}, + User: "user", + DB: "db", + AuthPlugin: "plugin", + AuthData: []byte("1234567890"), + Capability: ^mysql.ClientPluginAuthLenencClientData, + Collation: 0, + } + b := MakeHandshakeResponse(resp1) + resp2 := ParseHandshakeResponse(b) + require.Equal(t, resp1, resp2) +} diff --git a/pkg/proxy/net/packetio.go b/pkg/proxy/net/packetio.go index d9a37cc2..5071acdd 100644 --- a/pkg/proxy/net/packetio.go +++ b/pkg/proxy/net/packetio.go @@ -122,6 +122,14 @@ func (p *PacketIO) RemoteAddr() net.Addr { return p.conn.RemoteAddr() } +// SourceAddr returns the source address if proxy protocol is enabled. +func (p *PacketIO) SourceAddr() net.Addr { + if proxy := p.Proxy(); proxy != nil { + return proxy.SrcAddress + } + return p.conn.RemoteAddr() +} + func (p *PacketIO) ResetSequence() { p.sequence = 0 } diff --git a/pkg/proxy/net/protocol.go b/pkg/proxy/net/protocol.go index 76a02ef2..e8a7378f 100644 --- a/pkg/proxy/net/protocol.go +++ b/pkg/proxy/net/protocol.go @@ -35,7 +35,10 @@ package net -import "bytes" +import ( + "bytes" + "io" +) func ParseLengthEncodedInt(b []byte) (num uint64, isNull bool, n int) { switch b[0] { @@ -75,6 +78,23 @@ func ParseLengthEncodedInt(b []byte) (num uint64, isNull bool, n int) { return } +func ParseLengthEncodedBytes(b []byte) ([]byte, bool, int, error) { + // Get length + num, isNull, n := ParseLengthEncodedInt(b) + if num < 1 { + return nil, isNull, n, nil + } + + n += int(num) + + // Check data length + if len(b) >= n { + return b[n-int(num) : n], false, n, nil + } + + return nil, false, n, io.EOF +} + func ParseNullTermString(b []byte) (str []byte, remain []byte) { off := bytes.IndexByte(b, 0) if off == -1 { From 09ff5096c6a824d1685313bc7e5fae479edd3078 Mon Sep 17 00:00:00 2001 From: djshow832 <873581766@qq.com> Date: Wed, 30 Nov 2022 15:16:37 +0800 Subject: [PATCH 2/2] add one more test --- pkg/proxy/backend/authenticator_test.go | 7 ++- pkg/proxy/backend/backend_conn_mgr_test.go | 57 ++++++++++++++++++++-- pkg/proxy/backend/mock_backend_test.go | 11 ++--- pkg/proxy/backend/mock_client_test.go | 9 ++-- pkg/proxy/backend/mock_proxy_test.go | 19 +++++--- 5 files changed, 80 insertions(+), 23 deletions(-) diff --git a/pkg/proxy/backend/authenticator_test.go b/pkg/proxy/backend/authenticator_test.go index f00a9d77..320b1578 100644 --- a/pkg/proxy/backend/authenticator_test.go +++ b/pkg/proxy/backend/authenticator_test.go @@ -19,6 +19,7 @@ import ( "strings" "testing" + pnet "github.com/pingcap/TiProxy/pkg/proxy/net" "github.com/pingcap/tidb/parser/mysql" "github.com/stretchr/testify/require" ) @@ -212,8 +213,9 @@ func TestSecondHandshake(t *testing.T) { func TestCustomAuth(t *testing.T) { tc := newTCPConnSuite(t) handler := &CustomHandshakeHandler{ - outUsername: "rewritten_user", - outAttrs: map[string]string{"key": "value"}, + outUsername: "rewritten_user", + outAttrs: map[string]string{"key": "value"}, + outCapability: SupportedServerCapabilities & ^pnet.ClientDeprecateEOF, } ts, clean := newTestSuite(t, tc, func(cfg *testConfig) { cfg.proxyConfig.handler = handler @@ -222,6 +224,7 @@ func TestCustomAuth(t *testing.T) { require.Equal(t, ts.mc.username, handler.inUsername) require.Equal(t, handler.outUsername, ts.mb.username) require.Equal(t, handler.outAttrs, ts.mb.attrs) + require.Equal(t, handler.outCapability&pnet.ClientDeprecateEOF, pnet.Capability(ts.mb.capability)&pnet.ClientDeprecateEOF) host, _, err := net.SplitHostPort(handler.inAddr) require.NoError(t, err) require.Equal(t, host, "::1") diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index fbbb585b..f294b6e0 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -91,12 +91,12 @@ type backendMgrTester struct { closed bool } -func newBackendMgrTester(t *testing.T) *backendMgrTester { +func newBackendMgrTester(t *testing.T, cfg ...cfgOverrider) *backendMgrTester { tc := newTCPConnSuite(t) - cfg := func(cfg *testConfig) { + cfg = append(cfg, func(cfg *testConfig) { cfg.testSuiteConfig.initBackendConn = false - } - ts, clean := newTestSuite(t, tc, cfg) + }) + ts, clean := newTestSuite(t, tc, cfg...) tester := &backendMgrTester{ testSuite: ts, t: t, @@ -500,7 +500,7 @@ func TestSpecialCmds(t *testing.T) { require.Equal(t, "another_user", ts.mb.username) require.Equal(t, "session_db", ts.mb.db) expectCap := pnet.Capability(ts.mp.authenticator.supportedServerCapabilities.Uint32() &^ (mysql.ClientMultiStatements | mysql.ClientPluginAuthLenencClientData)) - gotCap := pnet.Capability(ts.mb.clientCapability &^ mysql.ClientPluginAuthLenencClientData) + gotCap := pnet.Capability(ts.mb.capability &^ mysql.ClientPluginAuthLenencClientData) require.Equal(t, expectCap, gotCap, "expected=%s,got=%s", expectCap, gotCap) return nil }, @@ -546,3 +546,50 @@ func TestCloseWhileRedirect(t *testing.T) { } ts.runTests(runners) } + +func TestCustomHandshake(t *testing.T) { + handler := &CustomHandshakeHandler{ + outUsername: "rewritten_user", + outAttrs: map[string]string{"key": "value"}, + outCapability: SupportedServerCapabilities & ^pnet.ClientDeprecateEOF, + } + ts := newBackendMgrTester(t, func(cfg *testConfig) { + //cfg.clientConfig.capability = handler.outCapability + cfg.proxyConfig.handler = handler + }) + runners := []runner{ + // 1st handshake + { + client: ts.mc.authenticate, + proxy: ts.firstHandshake4Proxy, + backend: ts.handshake4Backend, + }, + // query + { + client: func(packetIO *pnet.PacketIO) error { + ts.mc.sql = "select 1" + return ts.mc.request(packetIO) + }, + proxy: ts.forwardCmd4Proxy, + backend: func(packetIO *pnet.PacketIO) error { + ts.mb.respondType = responseTypeResultSet + ts.mb.columns = 1 + ts.mb.rows = 1 + return ts.mb.respond(packetIO) + }, + }, + // 2nd handshake + { + client: nil, + proxy: func(_, _ *pnet.PacketIO) error { + backend1 := ts.mp.backendConn + ts.mp.Redirect(ts.tc.backendListener.Addr().String()) + ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventSucceed) + require.NotEqual(t, backend1, ts.mp.backendConn) + return nil + }, + backend: ts.redirectSucceed4Backend, + }, + } + ts.runTests(runners) +} diff --git a/pkg/proxy/backend/mock_backend_test.go b/pkg/proxy/backend/mock_backend_test.go index dabde438..f29177e7 100644 --- a/pkg/proxy/backend/mock_backend_test.go +++ b/pkg/proxy/backend/mock_backend_test.go @@ -57,11 +57,10 @@ type mockBackend struct { // Inputs that assigned by the test and will be sent to the client. *backendConfig // Outputs that received from the client and will be checked by the test. - username string - db string - attrs map[string]string - authData []byte - clientCapability uint32 + username string + db string + attrs map[string]string + authData []byte } func newMockBackend(cfg *backendConfig) *mockBackend { @@ -101,7 +100,7 @@ func (mb *mockBackend) authenticate(packetIO *pnet.PacketIO) error { mb.db = resp.DB mb.authData = resp.AuthData mb.attrs = resp.Attrs - mb.clientCapability = resp.Capability + mb.capability = resp.Capability // verify password return mb.verifyPassword(packetIO, resp) } diff --git a/pkg/proxy/backend/mock_client_test.go b/pkg/proxy/backend/mock_client_test.go index de2504ea..bb28cb64 100644 --- a/pkg/proxy/backend/mock_client_test.go +++ b/pkg/proxy/backend/mock_client_test.go @@ -49,7 +49,7 @@ func newClientConfig() *clientConfig { dbName: mockDBName, authPlugin: mysql.AuthCachingSha2Password, authData: mockAuthData, - attrs: nil, + attrs: make(map[string]string), cmd: mysql.ComQuery, dataBytes: mockCmdBytes, sql: mockCmdStr, @@ -74,9 +74,12 @@ func (mc *mockClient) authenticate(packetIO *pnet.PacketIO) error { if mc.abnormalExit { return packetIO.Close() } - if _, err := packetIO.ReadPacket(); err != nil { + pkt, err := packetIO.ReadPacket() + if err != nil { return err } + serverCap := pnet.ParseInitialHandshake(pkt) + mc.capability = mc.capability & serverCap resp := &pnet.HandshakeResp{ User: mc.username, @@ -87,7 +90,7 @@ func (mc *mockClient) authenticate(packetIO *pnet.PacketIO) error { Capability: mc.capability, Collation: mc.collation, } - pkt := pnet.MakeHandshakeResponse(resp) + pkt = pnet.MakeHandshakeResponse(resp) if mc.capability&mysql.ClientSSL > 0 { if err := packetIO.WritePacket(pkt[:32], true); err != nil { return err diff --git a/pkg/proxy/backend/mock_proxy_test.go b/pkg/proxy/backend/mock_proxy_test.go index 3c346ece..3942b84a 100644 --- a/pkg/proxy/backend/mock_proxy_test.go +++ b/pkg/proxy/backend/mock_proxy_test.go @@ -64,9 +64,13 @@ func newMockProxy(t *testing.T, cfg *proxyConfig) *mockProxy { } func (mp *mockProxy) authenticateFirstTime(clientIO, backendIO *pnet.PacketIO) error { - return mp.authenticator.handshakeFirstTime(mp.logger, clientIO, mp.handshakeHandler, func(*Authenticator, *pnet.HandshakeResp) (*pnet.PacketIO, error) { + if err := mp.authenticator.handshakeFirstTime(mp.logger, clientIO, mp.handshakeHandler, func(*Authenticator, *pnet.HandshakeResp) (*pnet.PacketIO, error) { return backendIO, nil - }, mp.frontendTLSConfig, mp.backendTLSConfig) + }, mp.frontendTLSConfig, mp.backendTLSConfig); err != nil { + return err + } + mp.cmdProcessor.capability = mp.authenticator.capability + return nil } func (mp *mockProxy) authenticateSecondTime(clientIO, backendIO *pnet.PacketIO) error { @@ -96,10 +100,11 @@ func (mp *mockProxy) directQuery(_, backendIO *pnet.PacketIO) error { } type CustomHandshakeHandler struct { - inUsername string - inAddr string - outUsername string - outAttrs map[string]string + inUsername string + inAddr string + outCapability pnet.Capability + outUsername string + outAttrs map[string]string } func (handler *CustomHandshakeHandler) GetNamespace(nsMgr *namespace.NamespaceManager, resp *pnet.HandshakeResp) (*namespace.Namespace, error) { @@ -115,5 +120,5 @@ func (handler *CustomHandshakeHandler) HandleHandshakeResp(resp *pnet.HandshakeR } func (handler *CustomHandshakeHandler) GetCapability() pnet.Capability { - return SupportedServerCapabilities & ^pnet.ClientDeprecateEOF + return handler.outCapability }