Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

backend: Add HandshakeHandler to customize handling handshake #138

Merged
merged 2 commits into from
Dec 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions pkg/proxy/backend/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ const unknownAuthPlugin = "auth_unknown_plugin"
const requiredFrontendCaps = pnet.ClientProtocol41
const defRequiredBackendCaps = pnet.ClientDeprecateEOF

// Other server capabilities are not supported. ClientDeprecateEOF is supported but TiDB 6.2.0 doesn't support it now.
const supportedServerCapabilities = pnet.ClientLongPassword | pnet.ClientFoundRows | pnet.ClientConnectWithDB |
// SupportedServerCapabilities is the default supported capabilities. Other server capabilities are not supported.
// TiDB supports ClientDeprecateEOF since v6.3.0.
const SupportedServerCapabilities = pnet.ClientLongPassword | pnet.ClientFoundRows | pnet.ClientConnectWithDB |
pnet.ClientODBC | pnet.ClientLocalFiles | pnet.ClientInteractive | pnet.ClientLongFlag | pnet.ClientSSL |
pnet.ClientTransactions | pnet.ClientReserved | pnet.ClientSecureConnection | pnet.ClientMultiStatements |
pnet.ClientMultiResults | pnet.ClientPluginAuth | pnet.ClientConnectAttrs | pnet.ClientPluginAuthLenencClientData |
Expand All @@ -49,7 +50,7 @@ type Authenticator struct {
dbname string // default database name
serverAddr string
user string
attrs []byte // no need to parse
attrs map[string]string
salt []byte
capability uint32 // client capability
collation uint8
Expand All @@ -72,7 +73,7 @@ func (auth *Authenticator) writeProxyProtocol(clientIO, backendIO *pnet.PacketIO
Version: pnet.ProxyVersion2,
}
}
// either from another proxy or directly from clients, we are actings as a proxy
// either from another proxy or directly from clients, we are acting as a proxy
proxy.Command = pnet.ProxyCommandProxy
if err := backendIO.WriteProxyV2(proxy); err != nil {
return err
Expand All @@ -82,7 +83,7 @@ func (auth *Authenticator) writeProxyProtocol(clientIO, backendIO *pnet.PacketIO
}

func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapability pnet.Capability) error {
requiredBackendCaps := defRequiredBackendCaps
requiredBackendCaps := defRequiredBackendCaps & pnet.Capability(auth.capability)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? We don't do the verification for client and server, I think...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

verifyBackendCaps verifies that TiDB also has capability ClientDeprecateEOF. But if TiProxy doesn't enable ClientDeprecateEOF with the client, it shouldn't require ClientDeprecateEOF from the TiDB.

if auth.requireBackendTLS {
requiredBackendCaps |= pnet.ClientSSL
}
Expand All @@ -97,7 +98,8 @@ func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapabili
return nil
}

func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet.PacketIO, getBackendIO func(*Authenticator) (*pnet.PacketIO, error), frontendTLSConfig, backendTLSConfig *tls.Config) error {
func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet.PacketIO, handshakeHandler HandshakeHandler,
getBackend backendIOGetter, frontendTLSConfig, backendTLSConfig *tls.Config) error {
clientIO.ResetSequence()

proxyCapability := auth.supportedServerCapabilities
Expand Down Expand Up @@ -140,14 +142,18 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet
if frontendCapability^commonCaps != 0 {
logger.Debug("frontend send capabilities unsupported by proxy", zap.Stringer("common", commonCaps), zap.Stringer("frontend", frontendCapability^commonCaps), zap.Stringer("proxy", proxyCapability^commonCaps))
}
resp := pnet.ParseHandshakeResponse(pkt)
auth.capability = commonCaps.Uint32()

resp := pnet.ParseHandshakeResponse(pkt)
if err = handshakeHandler.HandleHandshakeResp(resp, clientIO.SourceAddr().String()); err != nil {
return err
}
auth.user = resp.User
auth.dbname = resp.DB
auth.collation = resp.Collation
auth.attrs = resp.Attrs

backendIO, err := getBackendIO(auth)
backendIO, err := getBackend(auth, resp)
if err != nil {
return err
}
Expand Down
30 changes: 29 additions & 1 deletion pkg/proxy/backend/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
package backend

import (
"net"
"strings"
"testing"

pnet "github.com/pingcap/TiProxy/pkg/proxy/net"
"github.com/pingcap/tidb/parser/mysql"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -162,7 +164,7 @@ func TestCapability(t *testing.T) {
},
func(cfg *testConfig) {
cfg.clientConfig.capability = defaultTestClientCapability | mysql.ClientConnectAtts
cfg.clientConfig.attrs = []byte(strings.Repeat("x", 512))
cfg.clientConfig.attrs = map[string]string{"key": "value"}
},
},
{
Expand Down Expand Up @@ -207,3 +209,29 @@ func TestSecondHandshake(t *testing.T) {
clean()
}
}

func TestCustomAuth(t *testing.T) {
tc := newTCPConnSuite(t)
handler := &CustomHandshakeHandler{
outUsername: "rewritten_user",
outAttrs: map[string]string{"key": "value"},
outCapability: SupportedServerCapabilities & ^pnet.ClientDeprecateEOF,
}
ts, clean := newTestSuite(t, tc, func(cfg *testConfig) {
cfg.proxyConfig.handler = handler
})
checker := func() {
require.Equal(t, ts.mc.username, handler.inUsername)
require.Equal(t, handler.outUsername, ts.mb.username)
require.Equal(t, handler.outAttrs, ts.mb.attrs)
require.Equal(t, handler.outCapability&pnet.ClientDeprecateEOF, pnet.Capability(ts.mb.capability)&pnet.ClientDeprecateEOF)
host, _, err := net.SplitHostPort(handler.inAddr)
require.NoError(t, err)
require.Equal(t, host, "::1")
}
ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) {})
checker()
ts.authenticateSecondTime(t, func(t *testing.T, ts *testSuite) {})
checker()
clean()
}
44 changes: 24 additions & 20 deletions pkg/proxy/backend/backend_conn_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ type redirectResult struct {
to string
}

type backendIOGetter func(auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error)

// BackendConnManager migrates a session from one BackendConnection to another.
//
// The signal processing goroutine tries to migrate the session once it receives a signal.
Expand All @@ -80,36 +82,36 @@ type BackendConnManager struct {
// redirectResCh is used to notify the event receiver asynchronously.
redirectResCh chan *redirectResult
// cancelFunc is used to cancel the signal processing goroutine.
cancelFunc context.CancelFunc
backendConn *BackendConnection
nsmgr *namespace.NamespaceManager
getBackendIO func(*Authenticator) (*pnet.PacketIO, error)
connectionID uint64
cancelFunc context.CancelFunc
backendConn *BackendConnection
nsmgr *namespace.NamespaceManager
handshakeHandler HandshakeHandler
getBackendIO backendIOGetter
connectionID uint64
}

// NewBackendConnManager creates a BackendConnManager.
func NewBackendConnManager(logger *zap.Logger, nsmgr *namespace.NamespaceManager, connectionID uint64, proxyProtocol, requireBackendTLS bool) *BackendConnManager {
func NewBackendConnManager(logger *zap.Logger, nsmgr *namespace.NamespaceManager, handshakeHandler HandshakeHandler,
connectionID uint64, proxyProtocol, requireBackendTLS bool) *BackendConnManager {
mgr := &BackendConnManager{
logger: logger,
connectionID: connectionID,
cmdProcessor: NewCmdProcessor(),
nsmgr: nsmgr,
logger: logger,
connectionID: connectionID,
cmdProcessor: NewCmdProcessor(),
nsmgr: nsmgr,
handshakeHandler: handshakeHandler,
authenticator: &Authenticator{
supportedServerCapabilities: supportedServerCapabilities,
supportedServerCapabilities: handshakeHandler.GetCapability(),
proxyProtocol: proxyProtocol,
requireBackendTLS: requireBackendTLS,
salt: GenerateSalt(20),
},
signalReceived: make(chan struct{}, 1),
redirectResCh: make(chan *redirectResult, 1),
}
mgr.getBackendIO = func(auth *Authenticator) (*pnet.PacketIO, error) {
ns, ok := mgr.nsmgr.GetNamespaceByUser(auth.user)
if !ok {
ns, ok = mgr.nsmgr.GetNamespace("default")
}
if !ok {
return nil, errors.New("failed to find a namespace")
mgr.getBackendIO = func(auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) {
ns, err := handshakeHandler.GetNamespace(nsmgr, resp)
if err != nil {
return nil, err
}
router := ns.GetRouter()
addr, err := router.Route(mgr)
Expand All @@ -135,15 +137,17 @@ func (mgr *BackendConnManager) ConnectionID() uint64 {
}

// Connect connects to the first backend and then start watching redirection signals.
func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.PacketIO, getBackendIO func(auth *Authenticator) (*pnet.PacketIO, error), frontendTLSConfig, backendTLSConfig *tls.Config) error {
func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.PacketIO, getBackendIO backendIOGetter,
frontendTLSConfig, backendTLSConfig *tls.Config) error {
mgr.processLock.Lock()
defer mgr.processLock.Unlock()

if getBackendIO == nil {
getBackendIO = mgr.getBackendIO
}

if err := mgr.authenticator.handshakeFirstTime(mgr.logger.Named("authenticator"), clientIO, getBackendIO, frontendTLSConfig, backendTLSConfig); err != nil {
if err := mgr.authenticator.handshakeFirstTime(mgr.logger.Named("authenticator"), clientIO, mgr.handshakeHandler,
getBackendIO, frontendTLSConfig, backendTLSConfig); err != nil {
return err
}

Expand Down
59 changes: 53 additions & 6 deletions pkg/proxy/backend/backend_conn_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ type backendMgrTester struct {
closed bool
}

func newBackendMgrTester(t *testing.T) *backendMgrTester {
func newBackendMgrTester(t *testing.T, cfg ...cfgOverrider) *backendMgrTester {
tc := newTCPConnSuite(t)
cfg := func(cfg *testConfig) {
cfg = append(cfg, func(cfg *testConfig) {
cfg.testSuiteConfig.initBackendConn = false
}
ts, clean := newTestSuite(t, tc, cfg)
})
ts, clean := newTestSuite(t, tc, cfg...)
tester := &backendMgrTester{
testSuite: ts,
t: t,
Expand All @@ -116,7 +116,7 @@ func newBackendMgrTester(t *testing.T) *backendMgrTester {
return tester
}

func (ts *backendMgrTester) getBackendIO(auth *Authenticator) (*pnet.PacketIO, error) {
func (ts *backendMgrTester) getBackendIO(auth *Authenticator, _ *pnet.HandshakeResp) (*pnet.PacketIO, error) {
addr := ts.tc.backendListener.Addr().String()
ts.mp.backendConn = NewBackendConnection(addr)
if err := ts.mp.backendConn.Connect(); err != nil {
Expand Down Expand Up @@ -500,7 +500,7 @@ func TestSpecialCmds(t *testing.T) {
require.Equal(t, "another_user", ts.mb.username)
require.Equal(t, "session_db", ts.mb.db)
expectCap := pnet.Capability(ts.mp.authenticator.supportedServerCapabilities.Uint32() &^ (mysql.ClientMultiStatements | mysql.ClientPluginAuthLenencClientData))
gotCap := pnet.Capability(ts.mb.clientCapability &^ mysql.ClientPluginAuthLenencClientData)
gotCap := pnet.Capability(ts.mb.capability &^ mysql.ClientPluginAuthLenencClientData)
require.Equal(t, expectCap, gotCap, "expected=%s,got=%s", expectCap, gotCap)
return nil
},
Expand Down Expand Up @@ -546,3 +546,50 @@ func TestCloseWhileRedirect(t *testing.T) {
}
ts.runTests(runners)
}

func TestCustomHandshake(t *testing.T) {
handler := &CustomHandshakeHandler{
outUsername: "rewritten_user",
outAttrs: map[string]string{"key": "value"},
outCapability: SupportedServerCapabilities & ^pnet.ClientDeprecateEOF,
}
ts := newBackendMgrTester(t, func(cfg *testConfig) {
//cfg.clientConfig.capability = handler.outCapability
cfg.proxyConfig.handler = handler
})
runners := []runner{
// 1st handshake
{
client: ts.mc.authenticate,
proxy: ts.firstHandshake4Proxy,
backend: ts.handshake4Backend,
},
// query
{
client: func(packetIO *pnet.PacketIO) error {
ts.mc.sql = "select 1"
return ts.mc.request(packetIO)
},
proxy: ts.forwardCmd4Proxy,
backend: func(packetIO *pnet.PacketIO) error {
ts.mb.respondType = responseTypeResultSet
ts.mb.columns = 1
ts.mb.rows = 1
return ts.mb.respond(packetIO)
},
},
// 2nd handshake
{
client: nil,
proxy: func(_, _ *pnet.PacketIO) error {
backend1 := ts.mp.backendConn
ts.mp.Redirect(ts.tc.backendListener.Addr().String())
ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventSucceed)
require.NotEqual(t, backend1, ts.mp.backendConn)
return nil
},
backend: ts.redirectSucceed4Backend,
},
}
ts.runTests(runners)
}
55 changes: 55 additions & 0 deletions pkg/proxy/backend/handshake_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright 2022 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package backend

import (
"github.com/pingcap/TiProxy/lib/util/errors"
"github.com/pingcap/TiProxy/pkg/manager/namespace"
pnet "github.com/pingcap/TiProxy/pkg/proxy/net"
)

var _ HandshakeHandler = (*DefaultHandshakeHandler)(nil)

type HandshakeHandler interface {
HandleHandshakeResp(resp *pnet.HandshakeResp, sourceAddr string) error
GetCapability() pnet.Capability
GetNamespace(nsMgr *namespace.NamespaceManager, resp *pnet.HandshakeResp) (*namespace.Namespace, error)
}

type DefaultHandshakeHandler struct {
}

func NewDefaultHandshakeHandler() *DefaultHandshakeHandler {
return &DefaultHandshakeHandler{}
}

func (handler *DefaultHandshakeHandler) HandleHandshakeResp(*pnet.HandshakeResp, string) error {
return nil
}

func (handler *DefaultHandshakeHandler) GetCapability() pnet.Capability {
return SupportedServerCapabilities
}

func (handler *DefaultHandshakeHandler) GetNamespace(nsMgr *namespace.NamespaceManager, resp *pnet.HandshakeResp) (*namespace.Namespace, error) {
ns, ok := nsMgr.GetNamespaceByUser(resp.User)
if !ok {
ns, ok = nsMgr.GetNamespace("default")
}
if !ok {
return nil, errors.New("failed to find a namespace")
}
return ns, nil
}
11 changes: 5 additions & 6 deletions pkg/proxy/backend/mock_backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,10 @@ type mockBackend struct {
// Inputs that assigned by the test and will be sent to the client.
*backendConfig
// Outputs that received from the client and will be checked by the test.
username string
db string
authData []byte
attrs []byte
clientCapability uint32
username string
db string
attrs map[string]string
authData []byte
}

func newMockBackend(cfg *backendConfig) *mockBackend {
Expand Down Expand Up @@ -101,7 +100,7 @@ func (mb *mockBackend) authenticate(packetIO *pnet.PacketIO) error {
mb.db = resp.DB
mb.authData = resp.AuthData
mb.attrs = resp.Attrs
mb.clientCapability = resp.Capability
mb.capability = resp.Capability
// verify password
return mb.verifyPassword(packetIO, resp)
}
Expand Down
Loading