diff --git a/pkg/manager/router/backend_selector.go b/pkg/manager/router/backend_selector.go new file mode 100644 index 00000000..ca7a4aee --- /dev/null +++ b/pkg/manager/router/backend_selector.go @@ -0,0 +1,38 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package router + +type BackendSelector struct { + excluded []string + cur string + routeOnce func(excluded []string) string + addConn func(addr string, conn RedirectableConn) error +} + +func (bs *BackendSelector) Reset() { + bs.excluded = bs.excluded[:0] +} + +func (bs *BackendSelector) Next() string { + if len(bs.cur) > 0 { + bs.excluded = append(bs.excluded, bs.cur) + } + bs.cur = bs.routeOnce(bs.excluded) + return bs.cur +} + +func (bs *BackendSelector) Succeed(conn RedirectableConn) error { + return bs.addConn(bs.cur, conn) +} diff --git a/pkg/manager/router/router.go b/pkg/manager/router/router.go index af5c3030..c3bfedde 100644 --- a/pkg/manager/router/router.go +++ b/pkg/manager/router/router.go @@ -37,7 +37,7 @@ type Router interface { // Router will handle connection events to balance connections if possible. ConnEventReceiver - Route(RedirectableConn) (string, error) + GetBackendSelector() BackendSelector RedirectConnections() error ConnCount() int Close() diff --git a/pkg/manager/router/router_score.go b/pkg/manager/router/router_score.go index 72910bbe..2b6091bc 100644 --- a/pkg/manager/router/router_score.go +++ b/pkg/manager/router/router_score.go @@ -61,27 +61,54 @@ func NewScoreBasedRouter(logger *zap.Logger, httpCli *http.Client, fetcher Backe return router, nil } -// Route implements Router.Route interface. -func (router *ScoreBasedRouter) Route(conn RedirectableConn) (string, error) { +// GetBackendSelector implements Router.GetBackendSelector interface. +func (router *ScoreBasedRouter) GetBackendSelector() BackendSelector { + return BackendSelector{ + routeOnce: router.routeOnce, + addConn: router.addNewConn, + } +} + +func (router *ScoreBasedRouter) routeOnce(excluded []string) string { router.Lock() defer router.Unlock() - be := router.backends.Back() - if be == nil { - return "", ErrNoInstanceToSelect - } - backend := be.Value.(*backendWrapper) - switch backend.status { - case StatusCannotConnect, StatusSchemaOutdated: - return "", ErrNoInstanceToSelect + for be := router.backends.Back(); be != nil; be = be.Prev() { + backend := be.Value.(*backendWrapper) + // These backends may be recycled, so we should not connect to them again. + switch backend.status { + case StatusCannotConnect, StatusSchemaOutdated: + continue + } + found := false + for _, ex := range excluded { + if ex == backend.addr { + found = true + break + } + } + if !found { + return backend.addr + } } + return "" +} + +func (router *ScoreBasedRouter) addNewConn(addr string, conn RedirectableConn) error { connWrapper := &connWrapper{ RedirectableConn: conn, phase: phaseNotRedirected, } + router.Lock() + be := router.lookupBackend(addr, true) + if be == nil { + router.Unlock() + return errors.WithStack(errors.Errorf("backend %s is not found in the router", addr)) + } router.addConn(be, connWrapper) - addBackendConnMetrics(backend.addr) + router.Unlock() + addBackendConnMetrics(addr) conn.SetEventReceiver(router) - return backend.addr, nil + return nil } func (router *ScoreBasedRouter) removeConn(be *list.Element, ce *list.Element) { diff --git a/pkg/manager/router/router_static.go b/pkg/manager/router/router_static.go index 8ff6dfb7..afcdfcbe 100644 --- a/pkg/manager/router/router_static.go +++ b/pkg/manager/router/router_static.go @@ -17,16 +17,36 @@ package router var _ Router = &StaticRouter{} type StaticRouter struct { - addr string + addr []string cnt int } -func NewStaticRouter(addr string) *StaticRouter { +func NewStaticRouter(addr []string) *StaticRouter { return &StaticRouter{addr: addr} } -func (r *StaticRouter) Route(c RedirectableConn) (string, error) { - return r.addr, nil +func (r *StaticRouter) GetBackendSelector() BackendSelector { + return BackendSelector{ + routeOnce: func(excluded []string) string { + for _, addr := range r.addr { + found := false + for _, e := range excluded { + if e == addr { + found = true + break + } + } + if !found { + return addr + } + } + return "" + }, + addConn: func(addr string, conn RedirectableConn) error { + r.cnt++ + return nil + }, + } } func (r *StaticRouter) RedirectConnections() error { @@ -34,7 +54,6 @@ func (r *StaticRouter) RedirectConnections() error { } func (r *StaticRouter) ConnCount() int { - r.cnt++ return r.cnt } diff --git a/pkg/manager/router/router_test.go b/pkg/manager/router/router_test.go index 585c9a39..c3ae7ab8 100644 --- a/pkg/manager/router/router_test.go +++ b/pkg/manager/router/router_test.go @@ -171,11 +171,21 @@ func (tester *routerTester) checkBackendOrder() { } } +func (tester *routerTester) simpleRoute(conn RedirectableConn) string { + selector := tester.router.GetBackendSelector() + addr := selector.Next() + if len(addr) > 0 { + err := selector.Succeed(conn) + require.NoError(tester.t, err) + } + return addr +} + func (tester *routerTester) addConnections(num int) { for i := 0; i < num; i++ { conn := tester.createConn() - addr, err := tester.router.Route(conn) - require.NoError(tester.t, err) + addr := tester.simpleRoute(conn) + require.True(tester.t, len(addr) > 0) conn.from = addr tester.conns[conn.connID] = conn } @@ -355,13 +365,49 @@ func TestConnBalanced(t *testing.T) { func TestNoBackends(t *testing.T) { tester := newRouterTester(t) conn := tester.createConn() - _, err := tester.router.Route(conn) - require.ErrorIs(t, err, ErrNoInstanceToSelect) + addr := tester.simpleRoute(conn) + require.True(t, len(addr) == 0) tester.addBackends(1) tester.addConnections(10) tester.killBackends(1) - _, err = tester.router.Route(conn) - require.ErrorIs(t, err, ErrNoInstanceToSelect) + addr = tester.simpleRoute(conn) + require.True(t, len(addr) == 0) +} + +// Test that the backends returned by the BackendSelector are complete and different. +func TestSelectorReturnOrder(t *testing.T) { + tester := newRouterTester(t) + tester.addBackends(3) + selector := tester.router.GetBackendSelector() + for i := 0; i < 3; i++ { + addrs := make(map[string]struct{}, 3) + for j := 0; j < 3; j++ { + addr := selector.Next() + addrs[addr] = struct{}{} + } + // All 3 addresses are different. + require.Equal(t, 3, len(addrs)) + addr := selector.Next() + require.True(t, len(addr) == 0) + selector.Reset() + } + + tester.killBackends(1) + for i := 0; i < 2; i++ { + addr := selector.Next() + require.True(t, len(addr) > 0) + } + addr := selector.Next() + require.True(t, len(addr) == 0) + selector.Reset() + + tester.addBackends(1) + for i := 0; i < 3; i++ { + addr := selector.Next() + require.True(t, len(addr) > 0) + } + addr = selector.Next() + require.True(t, len(addr) == 0) } // Test that the backends are balanced during rolling restart. @@ -578,12 +624,14 @@ func TestConcurrency(t *testing.T) { t: t, connID: connID, } - addr, err := router.Route(conn) - if err != nil { - require.ErrorIs(t, err, ErrNoInstanceToSelect) + selector := router.GetBackendSelector() + addr := selector.Next() + if len(addr) == 0 { conn = nil continue } + err = selector.Succeed(conn) + require.NoError(t, err) conn.from = addr } else if len(conn.GetRedirectingAddr()) > 0 { // redirecting, 70% success, 20% fail, 10% close diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index ab00a0a9..cb3e3711 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -20,6 +20,7 @@ import ( "fmt" "net" "sync" + "time" "github.com/pingcap/TiProxy/lib/util/errors" pnet "github.com/pingcap/TiProxy/pkg/proxy/net" @@ -101,7 +102,7 @@ func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapabili return nil } -type backendIOGetter func(ctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) +type backendIOGetter func(ctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp, timeout time.Duration) (*pnet.PacketIO, error) func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet.PacketIO, handshakeHandler HandshakeHandler, getBackendIO backendIOGetter, frontendTLSConfig, backendTLSConfig *tls.Config) error { @@ -160,7 +161,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet auth.attrs = resp.Attrs // In case of testing, backendIO is passed manually that we don't want to bother with the routing logic. - backendIO, err := getBackendIO(auth, auth, resp) + backendIO, err := getBackendIO(auth, auth, resp, 5*time.Second) if err != nil { return err } diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 7a00f5bb..7035c052 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -153,20 +153,41 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe return nil } -func (mgr *BackendConnManager) getBackendIO(ctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) { +func (mgr *BackendConnManager) getBackendIO(ctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp, timeout time.Duration) (*pnet.PacketIO, error) { r, err := mgr.handshakeHandler.GetRouter(auth, resp) if err != nil { return nil, err } - // wait for initialize - bctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - addr, err := backoff.RetryNotifyWithData( - func() (string, error) { - addr, err := r.Route(mgr) - if !errors.Is(err, router.ErrNoInstanceToSelect) { - return addr, backoff.Permanent(err) + // Reasons to wait: + // - The TiDB instances may not be initialized yet + // - One TiDB may be just shut down and another is just started but not ready yet + bctx, cancel := context.WithTimeout(context.Background(), timeout) + selector := r.GetBackendSelector() + io, err := backoff.RetryNotifyWithData( + func() (*pnet.PacketIO, error) { + // Try to connect to all backup backends one by one. + selector.Reset() + for { + addr := selector.Next() + if len(addr) == 0 { + return nil, router.ErrNoInstanceToSelect + } + backendConn := NewBackendConnection(addr) + err := backendConn.Connect() + mgr.handshakeHandler.OnHandshake(auth, addr, err) + if err == nil { + if err = selector.Succeed(mgr); err == nil { + mgr.logger.Info("connected to backend", zap.String("addr", addr)) + mgr.backendConn = backendConn + auth.serverAddr = addr + return mgr.backendConn.PacketIO(), nil + } + // Bad luck: the backend has been recycled or shut down just after the selector returns it. + if ignoredErr := backendConn.Close(); ignoredErr != nil { + mgr.logger.Error("close backend connection failed", zap.String("addr", addr), zap.Error(ignoredErr)) + } + } } - return addr, err }, backoff.WithContext(backoff.NewConstantBackOff(200*time.Millisecond), bctx), func(err error, d time.Duration) { @@ -174,19 +195,7 @@ func (mgr *BackendConnManager) getBackendIO(ctx ConnContext, auth *Authenticator }, ) cancel() - if err != nil { - return nil, err - } - - mgr.logger.Info("found", zap.String("addr", addr)) - mgr.backendConn = NewBackendConnection(addr) - if err := mgr.backendConn.Connect(); err != nil { - mgr.handshakeHandler.OnHandshake(auth, addr, err) - return nil, err - } - - auth.serverAddr = addr - return mgr.backendConn.PacketIO(), nil + return io, err } // ExecuteCmd forwards messages between the client and the backend. diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index 932dd2c7..c8702008 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -16,11 +16,13 @@ package backend import ( "context" + "net" "sync/atomic" "testing" "time" "github.com/pingcap/TiProxy/lib/util/errors" + "github.com/pingcap/TiProxy/lib/util/logger" "github.com/pingcap/TiProxy/lib/util/waitgroup" "github.com/pingcap/TiProxy/pkg/manager/router" pnet "github.com/pingcap/TiProxy/pkg/proxy/net" @@ -683,3 +685,42 @@ func TestGracefulCloseBeforeHandshake(t *testing.T) { } ts.runTests(runners) } + +func TestGetBackendIO(t *testing.T) { + listeners := make([]net.Listener, 0, 3) + addrs := make([]string, 0, 3) + for i := 0; i < 3; i++ { + listener, err := net.Listen("tcp", "0.0.0.0:0") + require.NoError(t, err) + listeners = append(listeners, listener) + addrs = append(addrs, listener.Addr().String()) + } + rt := router.NewStaticRouter(addrs) + badAddrs := make(map[string]struct{}, 3) + handler := &CustomHandshakeHandler{ + getRouter: func(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) { + return rt, nil + }, + onHandshake: func(connContext ConnContext, s string, err error) { + if err != nil && len(s) > 0 { + badAddrs[s] = struct{}{} + } + }, + } + mgr := NewBackendConnManager(logger.CreateLoggerForTest(t), handler, 0, false, false) + for i := 0; i <= len(listeners); i++ { + io, err := mgr.getBackendIO(mgr.authenticator, mgr.authenticator, nil, time.Second) + if err == nil { + require.NoError(t, io.Close()) + } + if i < len(listeners) { + require.NoError(t, err) + err = listeners[i].Close() + require.NoError(t, err) + } else { + require.ErrorIs(t, err, context.DeadlineExceeded) + } + require.True(t, len(badAddrs) <= i) + badAddrs = make(map[string]struct{}, 3) + } +} diff --git a/pkg/proxy/backend/mock_proxy_test.go b/pkg/proxy/backend/mock_proxy_test.go index 2ebb9625..52f6dee0 100644 --- a/pkg/proxy/backend/mock_proxy_test.go +++ b/pkg/proxy/backend/mock_proxy_test.go @@ -17,6 +17,7 @@ package backend import ( "crypto/tls" "testing" + "time" gomysql "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/TiProxy/lib/util/logger" @@ -63,7 +64,7 @@ func newMockProxy(t *testing.T, cfg *proxyConfig) *mockProxy { } func (mp *mockProxy) authenticateFirstTime(clientIO, backendIO *pnet.PacketIO) error { - if err := mp.authenticator.handshakeFirstTime(mp.logger, clientIO, mp.handshakeHandler, func(ConnContext, *Authenticator, *pnet.HandshakeResp) (*pnet.PacketIO, error) { + if err := mp.authenticator.handshakeFirstTime(mp.logger, clientIO, mp.handshakeHandler, func(ConnContext, *Authenticator, *pnet.HandshakeResp, time.Duration) (*pnet.PacketIO, error) { return backendIO, nil }, mp.frontendTLSConfig, mp.backendTLSConfig); err != nil { return err diff --git a/pkg/proxy/backend/testsuite_test.go b/pkg/proxy/backend/testsuite_test.go index 656907ad..71555baa 100644 --- a/pkg/proxy/backend/testsuite_test.go +++ b/pkg/proxy/backend/testsuite_test.go @@ -129,7 +129,7 @@ func newTestSuite(t *testing.T, tc *tcpConnSuite, overriders ...cfgOverrider) (* config.proxyConfig.frontendTLSConfig = tc.backendTLSConfig config.clientConfig.tlsConfig = tc.clientTLSConfig config.proxyConfig.handler.getRouter = func(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) { - return router.NewStaticRouter(ts.tc.backendListener.Addr().String()), nil + return router.NewStaticRouter([]string{ts.tc.backendListener.Addr().String()}), nil } })...) ts.mb = newMockBackend(cfg.backendConfig)