Skip to content

Commit

Permalink
backend: add OnHandshake (#158)
Browse files Browse the repository at this point in the history
  • Loading branch information
xhebox authored Dec 26, 2022
1 parent b5d2300 commit 2c9d13a
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 30 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/pingcap/TiProxy
go 1.19

require (
github.com/cenkalti/backoff/v4 v4.2.0
github.com/gin-contrib/pprof v1.4.0
github.com/gin-contrib/zap v0.0.2
github.com/gin-gonic/gin v1.8.1
Expand Down Expand Up @@ -31,7 +32,6 @@ require (
github.com/andres-erbsen/clock v0.0.0-20160526145045-9e14626cd129 // indirect
github.com/benbjohnson/clock v1.3.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cenkalti/backoff/v4 v4.2.0 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/cockroachdb/datadriven v1.0.0 // indirect
github.com/coocood/freecache v1.2.1 // indirect
Expand Down
11 changes: 10 additions & 1 deletion pkg/proxy/backend/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ type Authenticator struct {
supportedServerCapabilities pnet.Capability
dbname string // default database name
serverAddr string
clientAddr string
user string
attrs map[string]string
salt []byte
Expand Down Expand Up @@ -103,6 +104,7 @@ func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapabili
func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet.PacketIO, handshakeHandler HandshakeHandler,
getBackend backendIOGetter, frontendTLSConfig, backendTLSConfig *tls.Config) error {
clientIO.ResetSequence()
auth.clientAddr = clientIO.SourceAddr().String()

proxyCapability := auth.supportedServerCapabilities
if frontendTLSConfig == nil {
Expand Down Expand Up @@ -147,7 +149,6 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet
auth.capability = commonCaps.Uint32()

resp := pnet.ParseHandshakeResponse(pkt)
auth.SetValue(ContextKeyClientAddr, clientIO.SourceAddr().String())
if err = handshakeHandler.HandleHandshakeResp(auth, resp); err != nil {
return err
}
Expand Down Expand Up @@ -349,6 +350,14 @@ func (auth *Authenticator) updateCurrentDB(db string) {
auth.dbname = db
}

func (auth *Authenticator) ClientAddr() string {
return auth.clientAddr
}

func (auth *Authenticator) ServerAddr() string {
return auth.serverAddr
}

func (auth *Authenticator) SetValue(key, val any) {
auth.ctxmap.Store(key, val)
}
Expand Down
45 changes: 31 additions & 14 deletions pkg/proxy/backend/backend_conn_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"time"
"unsafe"

"github.com/cenkalti/backoff/v4"
gomysql "github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/TiProxy/lib/util/errors"
"github.com/pingcap/TiProxy/lib/util/waitgroup"
Expand Down Expand Up @@ -90,7 +91,6 @@ type BackendConnManager struct {
handshakeHandler HandshakeHandler
getBackendIO backendIOGetter
connectionID uint64
handshaked bool
}

// NewBackendConnManager creates a BackendConnManager.
Expand All @@ -116,24 +116,34 @@ func NewBackendConnManager(logger *zap.Logger, handshakeHandler HandshakeHandler
return nil, err
}
// wait for initialize
var addr string
for start := time.Now(); time.Since(start) < time.Second*5; {
addr, err = r.Route(mgr)
if !errors.Is(err, router.ErrNoInstanceToSelect) {
break
}
time.Sleep(time.Millisecond * 200)
}
bctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
addr, err := backoff.RetryNotifyWithData(
func() (string, error) {
addr, err := r.Route(mgr)
if !errors.Is(err, router.ErrNoInstanceToSelect) {
return addr, backoff.Permanent(err)
}
return addr, err
},
backoff.WithContext(backoff.NewConstantBackOff(200*time.Millisecond), bctx),
func(err error, d time.Duration) {
mgr.handshakeHandler.OnHandshake(ctx, "", err)
},
)
cancel()
if err != nil {
return nil, err
}

mgr.logger.Info("found", zap.String("addr", addr))
mgr.backendConn = NewBackendConnection(addr)
if err := mgr.backendConn.Connect(); err != nil {
mgr.handshakeHandler.OnHandshake(ctx, addr, err)
return nil, err
}
backendIO := mgr.backendConn.PacketIO()

auth.serverAddr = addr
backendIO := mgr.backendConn.PacketIO()
return backendIO, nil
}
return mgr
Expand All @@ -155,8 +165,10 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe
getBackendIO = mgr.getBackendIO
}

if err := mgr.authenticator.handshakeFirstTime(mgr.logger.Named("authenticator"), clientIO, mgr.handshakeHandler,
getBackendIO, frontendTLSConfig, backendTLSConfig); err != nil {
err := mgr.authenticator.handshakeFirstTime(mgr.logger.Named("authenticator"), clientIO, mgr.handshakeHandler,
getBackendIO, frontendTLSConfig, backendTLSConfig)
mgr.handshakeHandler.OnHandshake(mgr.authenticator, mgr.authenticator.serverAddr, err)
if err != nil {
return err
}

Expand All @@ -166,7 +178,6 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe
mgr.wg.Run(func() {
mgr.processSignals(childCtx, clientIO)
})
mgr.handshaked = true
return nil
}

Expand Down Expand Up @@ -319,10 +330,15 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context, clientIO *pnet.P

newConn := NewBackendConnection(rs.to)
if rs.err = newConn.Connect(); rs.err != nil {
mgr.handshakeHandler.OnHandshake(mgr.authenticator, rs.to, rs.err)
return
}
mgr.authenticator.serverAddr = rs.to
mgr.authenticator.clientAddr = clientIO.SourceAddr().String()
if rs.err = mgr.authenticator.handshakeSecondTime(mgr.logger, clientIO, newConn.PacketIO(), sessionToken); rs.err == nil {
rs.err = mgr.initSessionStates(newConn.PacketIO(), sessionStates)
} else {
mgr.handshakeHandler.OnHandshake(mgr.authenticator, mgr.authenticator.serverAddr, rs.err)
}
if rs.err != nil {
if ignoredErr := newConn.Close(); ignoredErr != nil && !pnet.IsDisconnectError(ignoredErr) {
Expand All @@ -334,6 +350,7 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context, clientIO *pnet.P
mgr.logger.Error("close previous backend connection failed", zap.Error(ignoredErr))
}
mgr.backendConn = newConn
mgr.handshakeHandler.OnHandshake(mgr.authenticator, mgr.authenticator.serverAddr, nil)
}

// The original db in the auth info may be dropped during the session, so we need to authenticate with the current db.
Expand Down Expand Up @@ -411,7 +428,7 @@ func (mgr *BackendConnManager) Close() error {
}
mgr.processLock.Unlock()

handErr := mgr.handshakeHandler.OnConnClose(mgr.authenticator, mgr.handshaked)
handErr := mgr.handshakeHandler.OnConnClose(mgr.authenticator)

eventReceiver := mgr.getEventReceiver()
if eventReceiver != nil {
Expand Down
20 changes: 8 additions & 12 deletions pkg/proxy/backend/handshake_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,21 @@ import (
pnet "github.com/pingcap/TiProxy/pkg/proxy/net"
)

type contextKey string

func (k contextKey) String() string {
return "handler context key " + string(k)
}

// Context keys.
var (
ContextKeyClientAddr contextKey = "client_addr"
)

var _ HandshakeHandler = (*DefaultHandshakeHandler)(nil)

type ConnContext interface {
ClientAddr() string
ServerAddr() string
SetValue(key, val any)
Value(key any) any
}

type HandshakeHandler interface {
HandleHandshakeResp(ctx ConnContext, resp *pnet.HandshakeResp) error
GetRouter(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error)
OnConnClose(ctx ConnContext, handshaked bool) error
OnHandshake(ctx ConnContext, to string, err error)
OnConnClose(ctx ConnContext) error
GetCapability() pnet.Capability
}

Expand Down Expand Up @@ -71,7 +64,10 @@ func (handler *DefaultHandshakeHandler) GetRouter(ctx ConnContext, resp *pnet.Ha
return ns.GetRouter(), nil
}

func (handler *DefaultHandshakeHandler) OnConnClose(ConnContext, bool) error {
func (handler *DefaultHandshakeHandler) OnHandshake(ConnContext, string, error) {
}

func (handler *DefaultHandshakeHandler) OnConnClose(ConnContext) error {
return nil
}

Expand Down
7 changes: 5 additions & 2 deletions pkg/proxy/backend/mock_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,17 @@ func (handler *CustomHandshakeHandler) GetRouter(ctx ConnContext, resp *pnet.Han
return nil, nil
}

func (handler *CustomHandshakeHandler) OnConnClose(ctx ConnContext, _ bool) error {
func (handler *CustomHandshakeHandler) OnHandshake(ctx ConnContext, _ string, _ error) {
}

func (handler *CustomHandshakeHandler) OnConnClose(ctx ConnContext) error {
return nil
}

func (handler *CustomHandshakeHandler) HandleHandshakeResp(ctx ConnContext, resp *pnet.HandshakeResp) error {
handler.inUsername = resp.User
resp.User = handler.outUsername
handler.inAddr = ctx.Value(ContextKeyClientAddr).(string)
handler.inAddr = ctx.ClientAddr()
resp.Attrs = handler.outAttrs
return nil
}
Expand Down

0 comments on commit 2c9d13a

Please sign in to comment.