From 876ba28cef905a4b13949f8c27d78601760d0959 Mon Sep 17 00:00:00 2001 From: djshow832 Date: Fri, 30 Sep 2022 14:24:50 +0800 Subject: [PATCH] backend: update db in the handshake (#102) --- .golangci.yaml | 3 -- pkg/proxy/backend/backend_conn_mgr.go | 4 ++ pkg/proxy/backend/backend_conn_mgr_test.go | 3 +- pkg/proxy/backend/mock_backend_test.go | 48 +++++++++++----------- 4 files changed, 30 insertions(+), 28 deletions(-) diff --git a/.golangci.yaml b/.golangci.yaml index f3e54ad1..a606753d 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -18,9 +18,6 @@ issues: linters: - gosec text: "G402:" - - linters: - - unused - source: "updateAuthInfoFromSessionStates" linters: enable: diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 052fa71a..a417fb6b 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -32,6 +32,7 @@ import ( "github.com/pingcap/TiProxy/pkg/manager/router" pnet "github.com/pingcap/TiProxy/pkg/proxy/net" "github.com/pingcap/tidb/parser/mysql" + "github.com/siddontang/go/hack" "go.uber.org/zap" ) @@ -265,6 +266,9 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) { if sessionStates, sessionToken, rs.err = mgr.querySessionStates(); rs.err != nil { return } + if rs.err = mgr.updateAuthInfoFromSessionStates(hack.Slice(sessionStates)); rs.err != nil { + return + } newConn := NewBackendConnection(rs.to) if rs.err = newConn.Connect(); rs.err != nil { diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index 7ec5b262..c20a34ce 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -484,9 +484,10 @@ func TestSpecialCmds(t *testing.T) { return nil }, backend: func(packetIO *pnet.PacketIO) error { + ts.mb.sessionStates = "{\"current-db\":\"session_db\"}" require.NoError(t, ts.redirectSucceed4Backend(packetIO)) require.Equal(t, "another_user", ts.mb.username) - require.Equal(t, "another_db", ts.mb.db) + require.Equal(t, "session_db", ts.mb.db) expectCap := pnet.Capability(ts.mp.authenticator.supportedServerCapabilities.Uint32() &^ (mysql.ClientMultiStatements | mysql.ClientPluginAuthLenencClientData)) gotCap := pnet.Capability(ts.mb.clientCapability &^ mysql.ClientPluginAuthLenencClientData) require.Equal(t, expectCap, gotCap, "expected=%s,got=%s", expectCap, gotCap) diff --git a/pkg/proxy/backend/mock_backend_test.go b/pkg/proxy/backend/mock_backend_test.go index 3ee11dd3..6682d68c 100644 --- a/pkg/proxy/backend/mock_backend_test.go +++ b/pkg/proxy/backend/mock_backend_test.go @@ -24,33 +24,33 @@ import ( ) type backendConfig struct { - // for auth - tlsConfig *tls.Config - authPlugin string - salt []byte - columns int - loops int - params int - rows int - respondType respondType // for cmd - stmtNum int - capability uint32 - status uint16 - authSucceed bool - switchAuth bool - // for both auth and cmd - abnormalExit bool + tlsConfig *tls.Config + authPlugin string + sessionStates string + salt []byte + columns int + loops int + params int + rows int + respondType respondType + stmtNum int + capability uint32 + status uint16 + authSucceed bool + switchAuth bool + abnormalExit bool } func newBackendConfig() *backendConfig { return &backendConfig{ - capability: defaultTestBackendCapability, - salt: mockSalt, - authPlugin: mysql.AuthCachingSha2Password, - switchAuth: true, - authSucceed: true, - loops: 1, - stmtNum: 1, + capability: defaultTestBackendCapability, + salt: mockSalt, + authPlugin: mysql.AuthCachingSha2Password, + switchAuth: true, + authSucceed: true, + loops: 1, + stmtNum: 1, + sessionStates: mockSessionStates, } } @@ -372,7 +372,7 @@ func (mb *mockBackend) respondSessionStates(packetIO *pnet.PacketIO) error { names := []string{sessionStatesCol, sessionTokenCol} values := [][]any{ { - mockSessionStates, mockCmdStr, + mb.sessionStates, mockCmdStr, }, } return mb.writeResultSet(packetIO, names, values)