diff --git a/go.mod b/go.mod index 2ef64a59..3240ffee 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/pingcap/TiProxy go 1.19 require ( + github.com/cenkalti/backoff/v4 v4.2.0 github.com/gin-contrib/pprof v1.4.0 github.com/gin-contrib/zap v0.0.2 github.com/gin-gonic/gin v1.8.1 @@ -31,7 +32,6 @@ require ( github.com/andres-erbsen/clock v0.0.0-20160526145045-9e14626cd129 // indirect github.com/benbjohnson/clock v1.3.0 // indirect github.com/beorn7/perks v1.0.1 // indirect - github.com/cenkalti/backoff/v4 v4.2.0 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/cockroachdb/datadriven v1.0.0 // indirect github.com/coocood/freecache v1.2.1 // indirect diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index 6954383a..d6d6a41b 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -51,6 +51,7 @@ type Authenticator struct { supportedServerCapabilities pnet.Capability dbname string // default database name serverAddr string + clientAddr string user string attrs map[string]string salt []byte @@ -103,6 +104,7 @@ func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapabili func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet.PacketIO, handshakeHandler HandshakeHandler, getBackend backendIOGetter, frontendTLSConfig, backendTLSConfig *tls.Config) error { clientIO.ResetSequence() + auth.clientAddr = clientIO.SourceAddr().String() proxyCapability := auth.supportedServerCapabilities if frontendTLSConfig == nil { @@ -147,7 +149,6 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet auth.capability = commonCaps.Uint32() resp := pnet.ParseHandshakeResponse(pkt) - auth.SetValue(ContextKeyClientAddr, clientIO.SourceAddr().String()) if err = handshakeHandler.HandleHandshakeResp(auth, resp); err != nil { return err } @@ -349,6 +350,14 @@ func (auth *Authenticator) updateCurrentDB(db string) { auth.dbname = db } +func (auth *Authenticator) ClientAddr() string { + return auth.clientAddr +} + +func (auth *Authenticator) ServerAddr() string { + return auth.serverAddr +} + func (auth *Authenticator) SetValue(key, val any) { auth.ctxmap.Store(key, val) } diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index ebd7ced2..9dc13b9a 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -26,6 +26,7 @@ import ( "time" "unsafe" + "github.com/cenkalti/backoff/v4" gomysql "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/TiProxy/lib/util/errors" "github.com/pingcap/TiProxy/lib/util/waitgroup" @@ -90,7 +91,6 @@ type BackendConnManager struct { handshakeHandler HandshakeHandler getBackendIO backendIOGetter connectionID uint64 - handshaked bool } // NewBackendConnManager creates a BackendConnManager. @@ -116,24 +116,34 @@ func NewBackendConnManager(logger *zap.Logger, handshakeHandler HandshakeHandler return nil, err } // wait for initialize - var addr string - for start := time.Now(); time.Since(start) < time.Second*5; { - addr, err = r.Route(mgr) - if !errors.Is(err, router.ErrNoInstanceToSelect) { - break - } - time.Sleep(time.Millisecond * 200) - } + bctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + addr, err := backoff.RetryNotifyWithData( + func() (string, error) { + addr, err := r.Route(mgr) + if !errors.Is(err, router.ErrNoInstanceToSelect) { + return addr, backoff.Permanent(err) + } + return addr, err + }, + backoff.WithContext(backoff.NewConstantBackOff(200*time.Millisecond), bctx), + func(err error, d time.Duration) { + mgr.handshakeHandler.OnHandshake(ctx, "", err) + }, + ) + cancel() if err != nil { return nil, err } + mgr.logger.Info("found", zap.String("addr", addr)) mgr.backendConn = NewBackendConnection(addr) if err := mgr.backendConn.Connect(); err != nil { + mgr.handshakeHandler.OnHandshake(ctx, addr, err) return nil, err } - backendIO := mgr.backendConn.PacketIO() + auth.serverAddr = addr + backendIO := mgr.backendConn.PacketIO() return backendIO, nil } return mgr @@ -155,8 +165,10 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe getBackendIO = mgr.getBackendIO } - if err := mgr.authenticator.handshakeFirstTime(mgr.logger.Named("authenticator"), clientIO, mgr.handshakeHandler, - getBackendIO, frontendTLSConfig, backendTLSConfig); err != nil { + err := mgr.authenticator.handshakeFirstTime(mgr.logger.Named("authenticator"), clientIO, mgr.handshakeHandler, + getBackendIO, frontendTLSConfig, backendTLSConfig) + mgr.handshakeHandler.OnHandshake(mgr.authenticator, mgr.authenticator.serverAddr, err) + if err != nil { return err } @@ -166,7 +178,6 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe mgr.wg.Run(func() { mgr.processSignals(childCtx, clientIO) }) - mgr.handshaked = true return nil } @@ -319,10 +330,15 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context, clientIO *pnet.P newConn := NewBackendConnection(rs.to) if rs.err = newConn.Connect(); rs.err != nil { + mgr.handshakeHandler.OnHandshake(mgr.authenticator, rs.to, rs.err) return } + mgr.authenticator.serverAddr = rs.to + mgr.authenticator.clientAddr = clientIO.SourceAddr().String() if rs.err = mgr.authenticator.handshakeSecondTime(mgr.logger, clientIO, newConn.PacketIO(), sessionToken); rs.err == nil { rs.err = mgr.initSessionStates(newConn.PacketIO(), sessionStates) + } else { + mgr.handshakeHandler.OnHandshake(mgr.authenticator, mgr.authenticator.serverAddr, rs.err) } if rs.err != nil { if ignoredErr := newConn.Close(); ignoredErr != nil && !pnet.IsDisconnectError(ignoredErr) { @@ -334,6 +350,7 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context, clientIO *pnet.P mgr.logger.Error("close previous backend connection failed", zap.Error(ignoredErr)) } mgr.backendConn = newConn + mgr.handshakeHandler.OnHandshake(mgr.authenticator, mgr.authenticator.serverAddr, nil) } // The original db in the auth info may be dropped during the session, so we need to authenticate with the current db. @@ -411,7 +428,7 @@ func (mgr *BackendConnManager) Close() error { } mgr.processLock.Unlock() - handErr := mgr.handshakeHandler.OnConnClose(mgr.authenticator, mgr.handshaked) + handErr := mgr.handshakeHandler.OnConnClose(mgr.authenticator) eventReceiver := mgr.getEventReceiver() if eventReceiver != nil { diff --git a/pkg/proxy/backend/handshake_handler.go b/pkg/proxy/backend/handshake_handler.go index f85d7698..802fa351 100644 --- a/pkg/proxy/backend/handshake_handler.go +++ b/pkg/proxy/backend/handshake_handler.go @@ -21,20 +21,12 @@ import ( pnet "github.com/pingcap/TiProxy/pkg/proxy/net" ) -type contextKey string - -func (k contextKey) String() string { - return "handler context key " + string(k) -} - // Context keys. -var ( - ContextKeyClientAddr contextKey = "client_addr" -) - var _ HandshakeHandler = (*DefaultHandshakeHandler)(nil) type ConnContext interface { + ClientAddr() string + ServerAddr() string SetValue(key, val any) Value(key any) any } @@ -42,7 +34,8 @@ type ConnContext interface { type HandshakeHandler interface { HandleHandshakeResp(ctx ConnContext, resp *pnet.HandshakeResp) error GetRouter(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) - OnConnClose(ctx ConnContext, handshaked bool) error + OnHandshake(ctx ConnContext, to string, err error) + OnConnClose(ctx ConnContext) error GetCapability() pnet.Capability } @@ -71,7 +64,10 @@ func (handler *DefaultHandshakeHandler) GetRouter(ctx ConnContext, resp *pnet.Ha return ns.GetRouter(), nil } -func (handler *DefaultHandshakeHandler) OnConnClose(ConnContext, bool) error { +func (handler *DefaultHandshakeHandler) OnHandshake(ConnContext, string, error) { +} + +func (handler *DefaultHandshakeHandler) OnConnClose(ConnContext) error { return nil } diff --git a/pkg/proxy/backend/mock_proxy_test.go b/pkg/proxy/backend/mock_proxy_test.go index 2698e472..a7a50c43 100644 --- a/pkg/proxy/backend/mock_proxy_test.go +++ b/pkg/proxy/backend/mock_proxy_test.go @@ -111,14 +111,17 @@ func (handler *CustomHandshakeHandler) GetRouter(ctx ConnContext, resp *pnet.Han return nil, nil } -func (handler *CustomHandshakeHandler) OnConnClose(ctx ConnContext, _ bool) error { +func (handler *CustomHandshakeHandler) OnHandshake(ctx ConnContext, _ string, _ error) { +} + +func (handler *CustomHandshakeHandler) OnConnClose(ctx ConnContext) error { return nil } func (handler *CustomHandshakeHandler) HandleHandshakeResp(ctx ConnContext, resp *pnet.HandshakeResp) error { handler.inUsername = resp.User resp.User = handler.outUsername - handler.inAddr = ctx.Value(ContextKeyClientAddr).(string) + handler.inAddr = ctx.ClientAddr() resp.Attrs = handler.outAttrs return nil }