From 8f2bb86478baeafd02dc70715924d896131ed799 Mon Sep 17 00:00:00 2001 From: xhe Date: Mon, 26 Dec 2022 15:13:39 +0800 Subject: [PATCH 1/8] backend: remove getBackendIO Signed-off-by: xhe --- pkg/manager/router/router.go | 373 +-------------------- pkg/manager/router/router_score.go | 370 ++++++++++++++++++++ pkg/manager/router/router_static.go | 55 +++ pkg/proxy/backend/authenticator.go | 48 ++- pkg/proxy/backend/authenticator_test.go | 32 +- pkg/proxy/backend/backend_conn_mgr.go | 50 +-- pkg/proxy/backend/backend_conn_mgr_test.go | 31 +- pkg/proxy/backend/common_test.go | 6 +- pkg/proxy/backend/handshake_handler.go | 49 +++ pkg/proxy/backend/mock_proxy_test.go | 40 +-- pkg/proxy/backend/testsuite_test.go | 4 + pkg/proxy/client/client_conn.go | 2 +- 12 files changed, 575 insertions(+), 485 deletions(-) create mode 100644 pkg/manager/router/router_score.go create mode 100644 pkg/manager/router/router_static.go diff --git a/pkg/manager/router/router.go b/pkg/manager/router/router.go index e81c9b7d..af5c3030 100644 --- a/pkg/manager/router/router.go +++ b/pkg/manager/router/router.go @@ -16,30 +16,33 @@ package router import ( "container/list" - "context" - "net/http" - "sync" "time" "github.com/pingcap/TiProxy/lib/util/errors" - "github.com/pingcap/TiProxy/lib/util/waitgroup" - "go.uber.org/zap" ) +var ( + ErrNoInstanceToSelect = errors.New("no instances to route") +) + +// ConnEventReceiver receives connection events. +type ConnEventReceiver interface { + OnRedirectSucceed(from, to string, conn RedirectableConn) error + OnRedirectFail(from, to string, conn RedirectableConn) error + OnConnClosed(addr string, conn RedirectableConn) error +} + // Router routes client connections to backends. type Router interface { + // Router will handle connection events to balance connections if possible. + ConnEventReceiver + Route(RedirectableConn) (string, error) RedirectConnections() error ConnCount() int Close() } -var _ Router = &ScoreBasedRouter{} - -var ( - ErrNoInstanceToSelect = errors.New("no instances to route") -) - type connPhase int const ( @@ -67,13 +70,6 @@ const ( redirectFailMinInterval = 3 * time.Second ) -// ConnEventReceiver receives connection events. -type ConnEventReceiver interface { - OnRedirectSucceed(from, to string, conn RedirectableConn) error - OnRedirectFail(from, to string, conn RedirectableConn) error - OnConnClosed(addr string, conn RedirectableConn) error -} - // RedirectableConn indicates a redirect-able connection. type RedirectableConn interface { SetEventReceiver(receiver ConnEventReceiver) @@ -104,344 +100,3 @@ type connWrapper struct { // Last redirect start time of this connection. lastRedirect time.Time } - -// ScoreBasedRouter is an implementation of Router interface. -// It routes a connection based on score. -type ScoreBasedRouter struct { - sync.Mutex - logger *zap.Logger - observer *BackendObserver - cancelFunc context.CancelFunc - wg waitgroup.WaitGroup - // A list of *backendWrapper. The backends are in descending order of scores. - backends *list.List -} - -// NewScoreBasedRouter creates a ScoreBasedRouter. -func NewScoreBasedRouter(logger *zap.Logger, httpCli *http.Client, fetcher BackendFetcher) (*ScoreBasedRouter, error) { - router := &ScoreBasedRouter{ - logger: logger, - backends: list.New(), - } - router.Lock() - defer router.Unlock() - observer, err := StartBackendObserver(logger.Named("observer"), router, httpCli, NewDefaultHealthCheckConfig(), fetcher) - if err != nil { - return nil, err - } - router.observer = observer - childCtx, cancelFunc := context.WithCancel(context.Background()) - router.cancelFunc = cancelFunc - router.wg.Run(func() { - router.rebalanceLoop(childCtx) - }) - return router, nil -} - -// Route implements Router.Route interface. -func (router *ScoreBasedRouter) Route(conn RedirectableConn) (string, error) { - router.Lock() - defer router.Unlock() - be := router.backends.Back() - if be == nil { - return "", ErrNoInstanceToSelect - } - backend := be.Value.(*backendWrapper) - switch backend.status { - case StatusCannotConnect, StatusSchemaOutdated: - return "", ErrNoInstanceToSelect - } - connWrapper := &connWrapper{ - RedirectableConn: conn, - phase: phaseNotRedirected, - } - router.addConn(be, connWrapper) - addBackendConnMetrics(backend.addr) - conn.SetEventReceiver(router) - return backend.addr, nil -} - -func (router *ScoreBasedRouter) removeConn(be *list.Element, ce *list.Element) { - backend := be.Value.(*backendWrapper) - conn := ce.Value.(*connWrapper) - backend.connList.Remove(ce) - delete(backend.connMap, conn.ConnectionID()) - if !router.removeBackendIfEmpty(be) { - router.adjustBackendList(be) - } -} - -func (router *ScoreBasedRouter) addConn(be *list.Element, conn *connWrapper) { - backend := be.Value.(*backendWrapper) - ce := backend.connList.PushBack(conn) - backend.connMap[conn.ConnectionID()] = ce - router.adjustBackendList(be) -} - -// adjustBackendList moves `be` after the score of `be` changes to keep the list ordered. -func (router *ScoreBasedRouter) adjustBackendList(be *list.Element) { - backend := be.Value.(*backendWrapper) - curScore := backend.score() - var mark *list.Element - for ele := be.Prev(); ele != nil; ele = ele.Prev() { - b := ele.Value.(*backendWrapper) - if b.score() >= curScore { - break - } - mark = ele - } - if mark != nil { - router.backends.MoveBefore(be, mark) - return - } - for ele := be.Next(); ele != nil; ele = ele.Next() { - b := ele.Value.(*backendWrapper) - if b.score() <= curScore { - break - } - mark = ele - } - if mark != nil { - router.backends.MoveAfter(be, mark) - } -} - -// RedirectConnections implements Router.RedirectConnections interface. -// It redirects all connections compulsively. It's only used for testing. -func (router *ScoreBasedRouter) RedirectConnections() error { - router.Lock() - defer router.Unlock() - for be := router.backends.Front(); be != nil; be = be.Next() { - backend := be.Value.(*backendWrapper) - for ce := backend.connList.Front(); ce != nil; ce = ce.Next() { - // This is only for test, so we allow it to reconnect to the same backend. - connWrapper := ce.Value.(*connWrapper) - if connWrapper.phase != phaseRedirectNotify { - connWrapper.phase = phaseRedirectNotify - connWrapper.Redirect(backend.addr) - } - } - } - return nil -} - -// forward is a hint to speed up searching. -func (router *ScoreBasedRouter) lookupBackend(addr string, forward bool) *list.Element { - if forward { - for be := router.backends.Front(); be != nil; be = be.Next() { - backend := be.Value.(*backendWrapper) - if backend.addr == addr { - return be - } - } - } else { - for be := router.backends.Back(); be != nil; be = be.Prev() { - backend := be.Value.(*backendWrapper) - if backend.addr == addr { - return be - } - } - } - return nil -} - -// OnRedirectSucceed implements ConnEventReceiver.OnRedirectSucceed interface. -func (router *ScoreBasedRouter) OnRedirectSucceed(from, to string, conn RedirectableConn) error { - router.Lock() - defer router.Unlock() - be := router.lookupBackend(to, false) - if be == nil { - return errors.WithStack(errors.Errorf("backend %s is not found in the router", to)) - } - toBackend := be.Value.(*backendWrapper) - e, ok := toBackend.connMap[conn.ConnectionID()] - if !ok { - return errors.WithStack(errors.Errorf("connection %d is not found on the backend %s", conn.ConnectionID(), to)) - } - connWrapper := e.Value.(*connWrapper) - connWrapper.phase = phaseRedirectEnd - addMigrateMetrics(from, to, true, connWrapper.lastRedirect) - subBackendConnMetrics(from) - addBackendConnMetrics(to) - return nil -} - -// OnRedirectFail implements ConnEventReceiver.OnRedirectFail interface. -func (router *ScoreBasedRouter) OnRedirectFail(from, to string, conn RedirectableConn) error { - router.Lock() - defer router.Unlock() - be := router.lookupBackend(to, false) - if be == nil { - return errors.WithStack(errors.Errorf("backend %s is not found in the router", to)) - } - toBackend := be.Value.(*backendWrapper) - ce, ok := toBackend.connMap[conn.ConnectionID()] - if !ok { - return errors.WithStack(errors.Errorf("connection %d is not found on the backend %s", conn.ConnectionID(), to)) - } - router.removeConn(be, ce) - - be = router.lookupBackend(from, true) - // The backend may have been removed because it's empty. Add it back. - if be == nil { - be = router.backends.PushBack(&backendWrapper{ - status: StatusCannotConnect, - addr: from, - connList: list.New(), - connMap: make(map[uint64]*list.Element), - }) - } - connWrapper := ce.Value.(*connWrapper) - connWrapper.phase = phaseRedirectFail - addMigrateMetrics(from, to, false, connWrapper.lastRedirect) - router.addConn(be, connWrapper) - return nil -} - -// OnConnClosed implements ConnEventReceiver.OnConnClosed interface. -func (router *ScoreBasedRouter) OnConnClosed(addr string, conn RedirectableConn) error { - router.Lock() - defer router.Unlock() - // Get the redirecting address in the lock, rather than letting the connection pass it in. - // While the connection closes, the router may also send a new redirection signal concurrently - // and move it to another backendWrapper. - if toAddr := conn.GetRedirectingAddr(); len(toAddr) > 0 { - addr = toAddr - } - be := router.lookupBackend(addr, true) - if be == nil { - return errors.WithStack(errors.Errorf("backend %s is not found in the router", addr)) - } - backend := be.Value.(*backendWrapper) - ce, ok := backend.connMap[conn.ConnectionID()] - if !ok { - return errors.WithStack(errors.Errorf("connection %d is not found on the backend %s", conn.ConnectionID(), addr)) - } - router.removeConn(be, ce) - subBackendConnMetrics(addr) - return nil -} - -// OnBackendChanged implements BackendEventReceiver.OnBackendChanged interface. -func (router *ScoreBasedRouter) OnBackendChanged(backends map[string]BackendStatus) { - router.Lock() - defer router.Unlock() - for addr, status := range backends { - be := router.lookupBackend(addr, true) - if be == nil && status != StatusCannotConnect { - router.logger.Info("find new backend", zap.String("addr", addr), - zap.String("status", status.String())) - be = router.backends.PushBack(&backendWrapper{ - status: status, - addr: addr, - connList: list.New(), - connMap: make(map[uint64]*list.Element), - }) - } else { - backend := be.Value.(*backendWrapper) - router.logger.Info("update backend", zap.String("addr", addr), - zap.String("prev_status", backend.status.String()), zap.String("cur_status", status.String())) - backend.status = status - } - if !router.removeBackendIfEmpty(be) { - router.adjustBackendList(be) - } - } -} - -func (router *ScoreBasedRouter) rebalanceLoop(ctx context.Context) { - for { - router.rebalance(rebalanceConnsPerLoop) - select { - case <-ctx.Done(): - return - case <-time.After(rebalanceInterval): - } - } -} - -func (router *ScoreBasedRouter) rebalance(maxNum int) { - curTime := time.Now() - router.Lock() - defer router.Unlock() - for i := 0; i < maxNum; i++ { - var busiestEle *list.Element - for be := router.backends.Front(); be != nil; be = be.Next() { - backend := be.Value.(*backendWrapper) - if backend.connList.Len() > 0 { - busiestEle = be - break - } - } - if busiestEle == nil { - break - } - busiestBackend := busiestEle.Value.(*backendWrapper) - idlestEle := router.backends.Back() - idlestBackend := idlestEle.Value.(*backendWrapper) - if float64(busiestBackend.score())/float64(idlestBackend.score()+1) < rebalanceMaxScoreRatio { - break - } - var ce *list.Element - for ele := busiestBackend.connList.Front(); ele != nil; ele = ele.Next() { - conn := ele.Value.(*connWrapper) - switch conn.phase { - case phaseRedirectNotify: - // A connection cannot be redirected again when it has not finished redirecting. - continue - case phaseRedirectFail: - // If it failed recently, it will probably fail this time. - if conn.lastRedirect.Add(redirectFailMinInterval).After(curTime) { - continue - } - } - ce = ele - break - } - if ce == nil { - break - } - router.removeConn(busiestEle, ce) - conn := ce.Value.(*connWrapper) - conn.phase = phaseRedirectNotify - conn.lastRedirect = curTime - router.addConn(idlestEle, conn) - conn.Redirect(idlestBackend.addr) - } -} - -func (router *ScoreBasedRouter) removeBackendIfEmpty(be *list.Element) bool { - backend := be.Value.(*backendWrapper) - if backend.status == StatusCannotConnect && backend.connList.Len() == 0 { - router.backends.Remove(be) - return true - } - return false -} - -func (router *ScoreBasedRouter) ConnCount() int { - router.Lock() - defer router.Unlock() - j := 0 - for be := router.backends.Front(); be != nil; be = be.Next() { - backend := be.Value.(*backendWrapper) - j += backend.connList.Len() - } - return j -} - -// Close implements Router.Close interface. -func (router *ScoreBasedRouter) Close() { - router.Lock() - defer router.Unlock() - if router.cancelFunc != nil { - router.cancelFunc() - router.cancelFunc = nil - } - if router.observer != nil { - router.observer.Close() - router.observer = nil - } - router.wg.Wait() - // Router only refers to RedirectableConn, it doesn't manage RedirectableConn. -} diff --git a/pkg/manager/router/router_score.go b/pkg/manager/router/router_score.go new file mode 100644 index 00000000..72910bbe --- /dev/null +++ b/pkg/manager/router/router_score.go @@ -0,0 +1,370 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package router + +import ( + "container/list" + "context" + "net/http" + "sync" + "time" + + "github.com/pingcap/TiProxy/lib/util/errors" + "github.com/pingcap/TiProxy/lib/util/waitgroup" + "go.uber.org/zap" +) + +var _ Router = &ScoreBasedRouter{} + +// ScoreBasedRouter is an implementation of Router interface. +// It routes a connection based on score. +type ScoreBasedRouter struct { + sync.Mutex + logger *zap.Logger + observer *BackendObserver + cancelFunc context.CancelFunc + wg waitgroup.WaitGroup + // A list of *backendWrapper. The backends are in descending order of scores. + backends *list.List +} + +// NewScoreBasedRouter creates a ScoreBasedRouter. +func NewScoreBasedRouter(logger *zap.Logger, httpCli *http.Client, fetcher BackendFetcher) (*ScoreBasedRouter, error) { + router := &ScoreBasedRouter{ + logger: logger, + backends: list.New(), + } + router.Lock() + defer router.Unlock() + observer, err := StartBackendObserver(logger.Named("observer"), router, httpCli, NewDefaultHealthCheckConfig(), fetcher) + if err != nil { + return nil, err + } + router.observer = observer + childCtx, cancelFunc := context.WithCancel(context.Background()) + router.cancelFunc = cancelFunc + router.wg.Run(func() { + router.rebalanceLoop(childCtx) + }) + return router, nil +} + +// Route implements Router.Route interface. +func (router *ScoreBasedRouter) Route(conn RedirectableConn) (string, error) { + router.Lock() + defer router.Unlock() + be := router.backends.Back() + if be == nil { + return "", ErrNoInstanceToSelect + } + backend := be.Value.(*backendWrapper) + switch backend.status { + case StatusCannotConnect, StatusSchemaOutdated: + return "", ErrNoInstanceToSelect + } + connWrapper := &connWrapper{ + RedirectableConn: conn, + phase: phaseNotRedirected, + } + router.addConn(be, connWrapper) + addBackendConnMetrics(backend.addr) + conn.SetEventReceiver(router) + return backend.addr, nil +} + +func (router *ScoreBasedRouter) removeConn(be *list.Element, ce *list.Element) { + backend := be.Value.(*backendWrapper) + conn := ce.Value.(*connWrapper) + backend.connList.Remove(ce) + delete(backend.connMap, conn.ConnectionID()) + if !router.removeBackendIfEmpty(be) { + router.adjustBackendList(be) + } +} + +func (router *ScoreBasedRouter) addConn(be *list.Element, conn *connWrapper) { + backend := be.Value.(*backendWrapper) + ce := backend.connList.PushBack(conn) + backend.connMap[conn.ConnectionID()] = ce + router.adjustBackendList(be) +} + +// adjustBackendList moves `be` after the score of `be` changes to keep the list ordered. +func (router *ScoreBasedRouter) adjustBackendList(be *list.Element) { + backend := be.Value.(*backendWrapper) + curScore := backend.score() + var mark *list.Element + for ele := be.Prev(); ele != nil; ele = ele.Prev() { + b := ele.Value.(*backendWrapper) + if b.score() >= curScore { + break + } + mark = ele + } + if mark != nil { + router.backends.MoveBefore(be, mark) + return + } + for ele := be.Next(); ele != nil; ele = ele.Next() { + b := ele.Value.(*backendWrapper) + if b.score() <= curScore { + break + } + mark = ele + } + if mark != nil { + router.backends.MoveAfter(be, mark) + } +} + +// RedirectConnections implements Router.RedirectConnections interface. +// It redirects all connections compulsively. It's only used for testing. +func (router *ScoreBasedRouter) RedirectConnections() error { + router.Lock() + defer router.Unlock() + for be := router.backends.Front(); be != nil; be = be.Next() { + backend := be.Value.(*backendWrapper) + for ce := backend.connList.Front(); ce != nil; ce = ce.Next() { + // This is only for test, so we allow it to reconnect to the same backend. + connWrapper := ce.Value.(*connWrapper) + if connWrapper.phase != phaseRedirectNotify { + connWrapper.phase = phaseRedirectNotify + connWrapper.Redirect(backend.addr) + } + } + } + return nil +} + +// forward is a hint to speed up searching. +func (router *ScoreBasedRouter) lookupBackend(addr string, forward bool) *list.Element { + if forward { + for be := router.backends.Front(); be != nil; be = be.Next() { + backend := be.Value.(*backendWrapper) + if backend.addr == addr { + return be + } + } + } else { + for be := router.backends.Back(); be != nil; be = be.Prev() { + backend := be.Value.(*backendWrapper) + if backend.addr == addr { + return be + } + } + } + return nil +} + +// OnRedirectSucceed implements ConnEventReceiver.OnRedirectSucceed interface. +func (router *ScoreBasedRouter) OnRedirectSucceed(from, to string, conn RedirectableConn) error { + router.Lock() + defer router.Unlock() + be := router.lookupBackend(to, false) + if be == nil { + return errors.WithStack(errors.Errorf("backend %s is not found in the router", to)) + } + toBackend := be.Value.(*backendWrapper) + e, ok := toBackend.connMap[conn.ConnectionID()] + if !ok { + return errors.WithStack(errors.Errorf("connection %d is not found on the backend %s", conn.ConnectionID(), to)) + } + connWrapper := e.Value.(*connWrapper) + connWrapper.phase = phaseRedirectEnd + addMigrateMetrics(from, to, true, connWrapper.lastRedirect) + subBackendConnMetrics(from) + addBackendConnMetrics(to) + return nil +} + +// OnRedirectFail implements ConnEventReceiver.OnRedirectFail interface. +func (router *ScoreBasedRouter) OnRedirectFail(from, to string, conn RedirectableConn) error { + router.Lock() + defer router.Unlock() + be := router.lookupBackend(to, false) + if be == nil { + return errors.WithStack(errors.Errorf("backend %s is not found in the router", to)) + } + toBackend := be.Value.(*backendWrapper) + ce, ok := toBackend.connMap[conn.ConnectionID()] + if !ok { + return errors.WithStack(errors.Errorf("connection %d is not found on the backend %s", conn.ConnectionID(), to)) + } + router.removeConn(be, ce) + + be = router.lookupBackend(from, true) + // The backend may have been removed because it's empty. Add it back. + if be == nil { + be = router.backends.PushBack(&backendWrapper{ + status: StatusCannotConnect, + addr: from, + connList: list.New(), + connMap: make(map[uint64]*list.Element), + }) + } + connWrapper := ce.Value.(*connWrapper) + connWrapper.phase = phaseRedirectFail + addMigrateMetrics(from, to, false, connWrapper.lastRedirect) + router.addConn(be, connWrapper) + return nil +} + +// OnConnClosed implements ConnEventReceiver.OnConnClosed interface. +func (router *ScoreBasedRouter) OnConnClosed(addr string, conn RedirectableConn) error { + router.Lock() + defer router.Unlock() + // Get the redirecting address in the lock, rather than letting the connection pass it in. + // While the connection closes, the router may also send a new redirection signal concurrently + // and move it to another backendWrapper. + if toAddr := conn.GetRedirectingAddr(); len(toAddr) > 0 { + addr = toAddr + } + be := router.lookupBackend(addr, true) + if be == nil { + return errors.WithStack(errors.Errorf("backend %s is not found in the router", addr)) + } + backend := be.Value.(*backendWrapper) + ce, ok := backend.connMap[conn.ConnectionID()] + if !ok { + return errors.WithStack(errors.Errorf("connection %d is not found on the backend %s", conn.ConnectionID(), addr)) + } + router.removeConn(be, ce) + subBackendConnMetrics(addr) + return nil +} + +// OnBackendChanged implements BackendEventReceiver.OnBackendChanged interface. +func (router *ScoreBasedRouter) OnBackendChanged(backends map[string]BackendStatus) { + router.Lock() + defer router.Unlock() + for addr, status := range backends { + be := router.lookupBackend(addr, true) + if be == nil && status != StatusCannotConnect { + router.logger.Info("find new backend", zap.String("addr", addr), + zap.String("status", status.String())) + be = router.backends.PushBack(&backendWrapper{ + status: status, + addr: addr, + connList: list.New(), + connMap: make(map[uint64]*list.Element), + }) + } else { + backend := be.Value.(*backendWrapper) + router.logger.Info("update backend", zap.String("addr", addr), + zap.String("prev_status", backend.status.String()), zap.String("cur_status", status.String())) + backend.status = status + } + if !router.removeBackendIfEmpty(be) { + router.adjustBackendList(be) + } + } +} + +func (router *ScoreBasedRouter) rebalanceLoop(ctx context.Context) { + for { + router.rebalance(rebalanceConnsPerLoop) + select { + case <-ctx.Done(): + return + case <-time.After(rebalanceInterval): + } + } +} + +func (router *ScoreBasedRouter) rebalance(maxNum int) { + curTime := time.Now() + router.Lock() + defer router.Unlock() + for i := 0; i < maxNum; i++ { + var busiestEle *list.Element + for be := router.backends.Front(); be != nil; be = be.Next() { + backend := be.Value.(*backendWrapper) + if backend.connList.Len() > 0 { + busiestEle = be + break + } + } + if busiestEle == nil { + break + } + busiestBackend := busiestEle.Value.(*backendWrapper) + idlestEle := router.backends.Back() + idlestBackend := idlestEle.Value.(*backendWrapper) + if float64(busiestBackend.score())/float64(idlestBackend.score()+1) < rebalanceMaxScoreRatio { + break + } + var ce *list.Element + for ele := busiestBackend.connList.Front(); ele != nil; ele = ele.Next() { + conn := ele.Value.(*connWrapper) + switch conn.phase { + case phaseRedirectNotify: + // A connection cannot be redirected again when it has not finished redirecting. + continue + case phaseRedirectFail: + // If it failed recently, it will probably fail this time. + if conn.lastRedirect.Add(redirectFailMinInterval).After(curTime) { + continue + } + } + ce = ele + break + } + if ce == nil { + break + } + router.removeConn(busiestEle, ce) + conn := ce.Value.(*connWrapper) + conn.phase = phaseRedirectNotify + conn.lastRedirect = curTime + router.addConn(idlestEle, conn) + conn.Redirect(idlestBackend.addr) + } +} + +func (router *ScoreBasedRouter) removeBackendIfEmpty(be *list.Element) bool { + backend := be.Value.(*backendWrapper) + if backend.status == StatusCannotConnect && backend.connList.Len() == 0 { + router.backends.Remove(be) + return true + } + return false +} + +func (router *ScoreBasedRouter) ConnCount() int { + router.Lock() + defer router.Unlock() + j := 0 + for be := router.backends.Front(); be != nil; be = be.Next() { + backend := be.Value.(*backendWrapper) + j += backend.connList.Len() + } + return j +} + +// Close implements Router.Close interface. +func (router *ScoreBasedRouter) Close() { + router.Lock() + defer router.Unlock() + if router.cancelFunc != nil { + router.cancelFunc() + router.cancelFunc = nil + } + if router.observer != nil { + router.observer.Close() + router.observer = nil + } + router.wg.Wait() + // Router only refers to RedirectableConn, it doesn't manage RedirectableConn. +} diff --git a/pkg/manager/router/router_static.go b/pkg/manager/router/router_static.go new file mode 100644 index 00000000..8ff6dfb7 --- /dev/null +++ b/pkg/manager/router/router_static.go @@ -0,0 +1,55 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package router + +var _ Router = &StaticRouter{} + +type StaticRouter struct { + addr string + cnt int +} + +func NewStaticRouter(addr string) *StaticRouter { + return &StaticRouter{addr: addr} +} + +func (r *StaticRouter) Route(c RedirectableConn) (string, error) { + return r.addr, nil +} + +func (r *StaticRouter) RedirectConnections() error { + return nil +} + +func (r *StaticRouter) ConnCount() int { + r.cnt++ + return r.cnt +} + +func (r *StaticRouter) Close() { +} + +func (r *StaticRouter) OnRedirectSucceed(from, to string, conn RedirectableConn) error { + return nil +} + +func (r *StaticRouter) OnRedirectFail(from, to string, conn RedirectableConn) error { + return nil +} + +func (r *StaticRouter) OnConnClosed(addr string, conn RedirectableConn) error { + r.cnt-- + return nil +} diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index d6d6a41b..21e614ff 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -15,13 +15,17 @@ package backend import ( + "context" "crypto/tls" "encoding/binary" "fmt" "net" "sync" + "time" + "github.com/cenkalti/backoff/v4" "github.com/pingcap/TiProxy/lib/util/errors" + "github.com/pingcap/TiProxy/pkg/manager/router" pnet "github.com/pingcap/TiProxy/pkg/proxy/net" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/util/hack" @@ -101,8 +105,9 @@ func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapabili return nil } -func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet.PacketIO, handshakeHandler HandshakeHandler, - getBackend backendIOGetter, frontendTLSConfig, backendTLSConfig *tls.Config) error { +func (mgr *BackendConnManager) handshakeFirstTime(logger *zap.Logger, clientIO, backendIO *pnet.PacketIO, handshakeHandler HandshakeHandler, frontendTLSConfig, backendTLSConfig *tls.Config) error { + auth := mgr.authenticator + clientIO.ResetSequence() auth.clientAddr = clientIO.SourceAddr().String() @@ -157,12 +162,43 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet auth.collation = resp.Collation auth.attrs = resp.Attrs - backendIO, err := getBackend(auth, auth, resp) - if err != nil { - return err - } + if backendIO == nil { + r, err := handshakeHandler.GetRouter(auth, resp) + if err != nil { + return err + } + // wait for initialize + bctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + addr, err := backoff.RetryNotifyWithData( + func() (string, error) { + addr, err := r.Route(mgr) + if !errors.Is(err, router.ErrNoInstanceToSelect) { + return addr, backoff.Permanent(err) + } + return addr, err + }, + backoff.WithContext(backoff.NewConstantBackOff(200*time.Millisecond), bctx), + func(err error, d time.Duration) { + mgr.handshakeHandler.OnHandshake(auth, "", err) + }, + ) + cancel() + if err != nil { + return err + } + mgr.logger.Info("found", zap.String("addr", addr)) + mgr.backendConn = NewBackendConnection(addr) + if err := mgr.backendConn.Connect(); err != nil { + mgr.handshakeHandler.OnHandshake(auth, addr, err) + return err + } + + auth.serverAddr = addr + backendIO = mgr.backendConn.PacketIO() + } backendIO.ResetSequence() + // write proxy header if err := auth.writeProxyProtocol(clientIO, backendIO); err != nil { return err diff --git a/pkg/proxy/backend/authenticator_test.go b/pkg/proxy/backend/authenticator_test.go index 320b1578..71a156cf 100644 --- a/pkg/proxy/backend/authenticator_test.go +++ b/pkg/proxy/backend/authenticator_test.go @@ -212,20 +212,30 @@ func TestSecondHandshake(t *testing.T) { func TestCustomAuth(t *testing.T) { tc := newTCPConnSuite(t) - handler := &CustomHandshakeHandler{ - outUsername: "rewritten_user", - outAttrs: map[string]string{"key": "value"}, - outCapability: SupportedServerCapabilities & ^pnet.ClientDeprecateEOF, - } + reUser := "rewritten_user" + reAttrs := map[string]string{"key": "value"} + reCap := SupportedServerCapabilities & ^pnet.ClientDeprecateEOF + inUser := "" + inAddr := "" ts, clean := newTestSuite(t, tc, func(cfg *testConfig) { - cfg.proxyConfig.handler = handler + handler := cfg.proxyConfig.handler + handler.handleHandshakeResp = func(ctx ConnContext, resp *pnet.HandshakeResp) error { + inUser = resp.User + inAddr = ctx.ClientAddr() + resp.User = reUser + resp.Attrs = reAttrs + return nil + } + handler.getCapability = func() pnet.Capability { + return reCap + } }) checker := func() { - require.Equal(t, ts.mc.username, handler.inUsername) - require.Equal(t, handler.outUsername, ts.mb.username) - require.Equal(t, handler.outAttrs, ts.mb.attrs) - require.Equal(t, handler.outCapability&pnet.ClientDeprecateEOF, pnet.Capability(ts.mb.capability)&pnet.ClientDeprecateEOF) - host, _, err := net.SplitHostPort(handler.inAddr) + require.Equal(t, ts.mc.username, inUser) + require.Equal(t, reUser, ts.mb.username) + require.Equal(t, reAttrs, ts.mb.attrs) + require.Equal(t, reCap&pnet.ClientDeprecateEOF, pnet.Capability(ts.mb.capability)&pnet.ClientDeprecateEOF) + host, _, err := net.SplitHostPort(inAddr) require.NoError(t, err) require.Equal(t, host, "::1") } diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 9dc13b9a..1e42d76e 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -26,7 +26,6 @@ import ( "time" "unsafe" - "github.com/cenkalti/backoff/v4" gomysql "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/TiProxy/lib/util/errors" "github.com/pingcap/TiProxy/lib/util/waitgroup" @@ -59,8 +58,6 @@ type redirectResult struct { to string } -type backendIOGetter func(ctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) - // BackendConnManager migrates a session from one BackendConnection to another. // // The signal processing goroutine tries to migrate the session once it receives a signal. @@ -89,7 +86,6 @@ type BackendConnManager struct { cancelFunc context.CancelFunc backendConn *BackendConnection handshakeHandler HandshakeHandler - getBackendIO backendIOGetter connectionID uint64 } @@ -110,42 +106,6 @@ func NewBackendConnManager(logger *zap.Logger, handshakeHandler HandshakeHandler signalReceived: make(chan struct{}, 1), redirectResCh: make(chan *redirectResult, 1), } - mgr.getBackendIO = func(ctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) { - r, err := handshakeHandler.GetRouter(ctx, resp) - if err != nil { - return nil, err - } - // wait for initialize - bctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - addr, err := backoff.RetryNotifyWithData( - func() (string, error) { - addr, err := r.Route(mgr) - if !errors.Is(err, router.ErrNoInstanceToSelect) { - return addr, backoff.Permanent(err) - } - return addr, err - }, - backoff.WithContext(backoff.NewConstantBackOff(200*time.Millisecond), bctx), - func(err error, d time.Duration) { - mgr.handshakeHandler.OnHandshake(ctx, "", err) - }, - ) - cancel() - if err != nil { - return nil, err - } - - mgr.logger.Info("found", zap.String("addr", addr)) - mgr.backendConn = NewBackendConnection(addr) - if err := mgr.backendConn.Connect(); err != nil { - mgr.handshakeHandler.OnHandshake(ctx, addr, err) - return nil, err - } - - auth.serverAddr = addr - backendIO := mgr.backendConn.PacketIO() - return backendIO, nil - } return mgr } @@ -156,17 +116,11 @@ func (mgr *BackendConnManager) ConnectionID() uint64 { } // Connect connects to the first backend and then start watching redirection signals. -func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.PacketIO, getBackendIO backendIOGetter, - frontendTLSConfig, backendTLSConfig *tls.Config) error { +func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.PacketIO, frontendTLSConfig, backendTLSConfig *tls.Config) error { mgr.processLock.Lock() defer mgr.processLock.Unlock() - if getBackendIO == nil { - getBackendIO = mgr.getBackendIO - } - - err := mgr.authenticator.handshakeFirstTime(mgr.logger.Named("authenticator"), clientIO, mgr.handshakeHandler, - getBackendIO, frontendTLSConfig, backendTLSConfig) + err := mgr.handshakeFirstTime(mgr.logger.Named("authenticator"), clientIO, nil, mgr.handshakeHandler, frontendTLSConfig, backendTLSConfig) mgr.handshakeHandler.OnHandshake(mgr.authenticator, mgr.authenticator.serverAddr, err) if err != nil { return err diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index 8f5e99b8..9293fa11 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -116,20 +116,9 @@ func newBackendMgrTester(t *testing.T, cfg ...cfgOverrider) *backendMgrTester { return tester } -func (ts *backendMgrTester) getBackendIO(ctx ConnContext, auth *Authenticator, _ *pnet.HandshakeResp) (*pnet.PacketIO, error) { - addr := ts.tc.backendListener.Addr().String() - ts.mp.backendConn = NewBackendConnection(addr) - if err := ts.mp.backendConn.Connect(); err != nil { - return nil, err - } - backendIO := ts.mp.backendConn.PacketIO() - auth.serverAddr = addr - return backendIO, nil -} - // Define some common runners here to reduce code redundancy. func (ts *backendMgrTester) firstHandshake4Proxy(clientIO, backendIO *pnet.PacketIO) error { - err := ts.mp.Connect(context.Background(), clientIO, ts.getBackendIO, ts.mp.frontendTLSConfig, ts.mp.backendTLSConfig) + err := ts.mp.Connect(context.Background(), clientIO, ts.mp.frontendTLSConfig, ts.mp.backendTLSConfig) require.NoError(ts.t, err) mer := newMockEventReceiver() ts.mp.SetEventReceiver(mer) @@ -371,7 +360,7 @@ func TestConnectFail(t *testing.T) { { client: ts.mc.authenticate, proxy: func(clientIO, backendIO *pnet.PacketIO) error { - return ts.mp.Connect(context.Background(), clientIO, ts.getBackendIO, ts.mp.frontendTLSConfig, ts.mp.backendTLSConfig) + return ts.mp.Connect(context.Background(), clientIO, ts.mp.frontendTLSConfig, ts.mp.backendTLSConfig) }, backend: func(_ *pnet.PacketIO) error { conn, err := ts.tc.backendListener.Accept() @@ -548,14 +537,16 @@ func TestCloseWhileRedirect(t *testing.T) { } func TestCustomHandshake(t *testing.T) { - handler := &CustomHandshakeHandler{ - outUsername: "rewritten_user", - outAttrs: map[string]string{"key": "value"}, - outCapability: SupportedServerCapabilities & ^pnet.ClientDeprecateEOF, - } ts := newBackendMgrTester(t, func(cfg *testConfig) { - //cfg.clientConfig.capability = handler.outCapability - cfg.proxyConfig.handler = handler + handler := cfg.proxyConfig.handler + handler.handleHandshakeResp = func(ctx ConnContext, resp *pnet.HandshakeResp) error { + resp.User = "rewritten_user" + resp.Attrs = map[string]string{"key": "value"} + return nil + } + handler.getCapability = func() pnet.Capability { + return SupportedServerCapabilities & ^pnet.ClientDeprecateEOF + } }) runners := []runner{ // 1st handshake diff --git a/pkg/proxy/backend/common_test.go b/pkg/proxy/backend/common_test.go index 6e3c89c4..026ff719 100644 --- a/pkg/proxy/backend/common_test.go +++ b/pkg/proxy/backend/common_test.go @@ -20,8 +20,8 @@ import ( "testing" "github.com/pingcap/TiProxy/lib/util/security" + "github.com/pingcap/TiProxy/lib/util/waitgroup" pnet "github.com/pingcap/TiProxy/pkg/proxy/net" - "github.com/pingcap/tidb/util" "github.com/stretchr/testify/require" ) @@ -56,7 +56,7 @@ func newTCPConnSuite(t *testing.T) *tcpConnSuite { } func (tc *tcpConnSuite) newConn(t *testing.T, withBackend bool) func() { - var wg util.WaitGroupWrapper + var wg waitgroup.WaitGroup if withBackend { wg.Run(func() { conn, err := tc.backendListener.Accept() @@ -92,7 +92,7 @@ func (tc *tcpConnSuite) newConn(t *testing.T, withBackend bool) func() { } func (tc *tcpConnSuite) run(clientRunner, backendRunner func(*pnet.PacketIO) error, proxyRunner func(*pnet.PacketIO, *pnet.PacketIO) error) (cerr, berr, perr error) { - var wg util.WaitGroupWrapper + var wg waitgroup.WaitGroup if clientRunner != nil { wg.Run(func() { cerr = clientRunner(tc.clientIO) diff --git a/pkg/proxy/backend/handshake_handler.go b/pkg/proxy/backend/handshake_handler.go index 802fa351..5d03cfa5 100644 --- a/pkg/proxy/backend/handshake_handler.go +++ b/pkg/proxy/backend/handshake_handler.go @@ -74,3 +74,52 @@ func (handler *DefaultHandshakeHandler) OnConnClose(ConnContext) error { func (handler *DefaultHandshakeHandler) GetCapability() pnet.Capability { return SupportedServerCapabilities } + +type CustomHandshakeHandler struct { + getRouter func(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) + onHandshake func(ConnContext, string, error) + onConnClose func(ConnContext) error + handleHandshakeResp func(ctx ConnContext, resp *pnet.HandshakeResp) error + getCapability func() pnet.Capability +} + +func (h *CustomHandshakeHandler) GetRouter(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) { + if h.getRouter != nil { + return h.getRouter(ctx, resp) + } + return nil, errors.New("no router") +} + +func (h *CustomHandshakeHandler) OnHandshake(ctx ConnContext, addr string, err error) { + if h.onHandshake != nil { + h.onHandshake(ctx, addr, err) + } +} + +func (h *CustomHandshakeHandler) OnConnClose(ctx ConnContext) error { + if h.onConnClose != nil { + return h.onConnClose(ctx) + } + return nil +} + +func (h *CustomHandshakeHandler) HandleHandshakeResp(ctx ConnContext, resp *pnet.HandshakeResp) error { + if h.handleHandshakeResp != nil { + return h.handleHandshakeResp(ctx, resp) + } + return nil + /* + h.inUsername = resp.User + resp.User = h.outUsername + h.inAddr = ctx.ClientAddr() + resp.Attrs = h.outAttrs + return nil + */ +} + +func (h *CustomHandshakeHandler) GetCapability() pnet.Capability { + if h.getCapability != nil { + return h.getCapability() + } + return SupportedServerCapabilities +} diff --git a/pkg/proxy/backend/mock_proxy_test.go b/pkg/proxy/backend/mock_proxy_test.go index a7a50c43..64cbd1bd 100644 --- a/pkg/proxy/backend/mock_proxy_test.go +++ b/pkg/proxy/backend/mock_proxy_test.go @@ -20,7 +20,6 @@ import ( gomysql "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/TiProxy/lib/util/logger" - "github.com/pingcap/TiProxy/pkg/manager/router" pnet "github.com/pingcap/TiProxy/pkg/proxy/net" "go.uber.org/zap" ) @@ -28,7 +27,7 @@ import ( type proxyConfig struct { frontendTLSConfig *tls.Config backendTLSConfig *tls.Config - handler HandshakeHandler + handler *CustomHandshakeHandler sessionToken string capability uint32 waitRedirect bool @@ -36,7 +35,7 @@ type proxyConfig struct { func newProxyConfig() *proxyConfig { return &proxyConfig{ - handler: NewDefaultHandshakeHandler(nil), + handler: &CustomHandshakeHandler{}, capability: defaultTestBackendCapability, sessionToken: mockToken, } @@ -64,9 +63,7 @@ func newMockProxy(t *testing.T, cfg *proxyConfig) *mockProxy { } func (mp *mockProxy) authenticateFirstTime(clientIO, backendIO *pnet.PacketIO) error { - if err := mp.authenticator.handshakeFirstTime(mp.logger, clientIO, mp.handshakeHandler, func(ConnContext, *Authenticator, *pnet.HandshakeResp) (*pnet.PacketIO, error) { - return backendIO, nil - }, mp.frontendTLSConfig, mp.backendTLSConfig); err != nil { + if err := mp.handshakeFirstTime(mp.logger, clientIO, backendIO, mp.handshakeHandler, mp.frontendTLSConfig, mp.backendTLSConfig); err != nil { return err } mp.cmdProcessor.capability = mp.authenticator.capability @@ -98,34 +95,3 @@ func (mp *mockProxy) directQuery(_, backendIO *pnet.PacketIO) error { mp.rs = rs return err } - -type CustomHandshakeHandler struct { - inUsername string - inAddr string - outCapability pnet.Capability - outUsername string - outAttrs map[string]string -} - -func (handler *CustomHandshakeHandler) GetRouter(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) { - return nil, nil -} - -func (handler *CustomHandshakeHandler) OnHandshake(ctx ConnContext, _ string, _ error) { -} - -func (handler *CustomHandshakeHandler) OnConnClose(ctx ConnContext) error { - return nil -} - -func (handler *CustomHandshakeHandler) HandleHandshakeResp(ctx ConnContext, resp *pnet.HandshakeResp) error { - handler.inUsername = resp.User - resp.User = handler.outUsername - handler.inAddr = ctx.ClientAddr() - resp.Attrs = handler.outAttrs - return nil -} - -func (handler *CustomHandshakeHandler) GetCapability() pnet.Capability { - return handler.outCapability -} diff --git a/pkg/proxy/backend/testsuite_test.go b/pkg/proxy/backend/testsuite_test.go index 3ed52df7..32384da3 100644 --- a/pkg/proxy/backend/testsuite_test.go +++ b/pkg/proxy/backend/testsuite_test.go @@ -19,6 +19,7 @@ import ( "strings" "testing" + "github.com/pingcap/TiProxy/pkg/manager/router" pnet "github.com/pingcap/TiProxy/pkg/proxy/net" "github.com/pingcap/tidb/parser/mysql" "github.com/stretchr/testify/require" @@ -127,6 +128,9 @@ func newTestSuite(t *testing.T, tc *tcpConnSuite, overriders ...cfgOverrider) (* config.proxyConfig.backendTLSConfig = tc.clientTLSConfig config.proxyConfig.frontendTLSConfig = tc.backendTLSConfig config.clientConfig.tlsConfig = tc.clientTLSConfig + config.proxyConfig.handler.getRouter = func(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) { + return router.NewStaticRouter(ts.tc.backendListener.Addr().String()), nil + } })...) ts.mb = newMockBackend(cfg.backendConfig) ts.mp = newMockProxy(t, cfg.proxyConfig) diff --git a/pkg/proxy/client/client_conn.go b/pkg/proxy/client/client_conn.go index 81b9b871..07a37a69 100644 --- a/pkg/proxy/client/client_conn.go +++ b/pkg/proxy/client/client_conn.go @@ -61,7 +61,7 @@ func (cc *ClientConnection) Run(ctx context.Context) { var err error var msg string - if err = cc.connMgr.Connect(ctx, cc.pkt, nil, cc.frontendTLSConfig, cc.backendTLSConfig); err != nil { + if err = cc.connMgr.Connect(ctx, cc.pkt, cc.frontendTLSConfig, cc.backendTLSConfig); err != nil { msg = "new connection failed" goto clean } From 99417525e76aab992c30d708196463815b7f8396 Mon Sep 17 00:00:00 2001 From: xhe Date: Mon, 26 Dec 2022 15:20:31 +0800 Subject: [PATCH 2/8] *: remove useless comments Signed-off-by: xhe --- pkg/proxy/backend/handshake_handler.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/pkg/proxy/backend/handshake_handler.go b/pkg/proxy/backend/handshake_handler.go index 5d03cfa5..84407907 100644 --- a/pkg/proxy/backend/handshake_handler.go +++ b/pkg/proxy/backend/handshake_handler.go @@ -108,13 +108,6 @@ func (h *CustomHandshakeHandler) HandleHandshakeResp(ctx ConnContext, resp *pnet return h.handleHandshakeResp(ctx, resp) } return nil - /* - h.inUsername = resp.User - resp.User = h.outUsername - h.inAddr = ctx.ClientAddr() - resp.Attrs = h.outAttrs - return nil - */ } func (h *CustomHandshakeHandler) GetCapability() pnet.Capability { From e7c7788fd2ca1a547e9ab4cb6e0aa51c7a3e2525 Mon Sep 17 00:00:00 2001 From: xhe Date: Mon, 26 Dec 2022 15:23:14 +0800 Subject: [PATCH 3/8] *: add comments Signed-off-by: xhe --- pkg/proxy/backend/authenticator.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index 21e614ff..56ad57af 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -162,6 +162,7 @@ func (mgr *BackendConnManager) handshakeFirstTime(logger *zap.Logger, clientIO, auth.collation = resp.Collation auth.attrs = resp.Attrs + // In case of testing, backendIO is passed manually that we don't want to bother with the routing logic. if backendIO == nil { r, err := handshakeHandler.GetRouter(auth, resp) if err != nil { From 894917f5b31e03206236eb5fe488fa3c40aa9853 Mon Sep 17 00:00:00 2001 From: xhe Date: Mon, 26 Dec 2022 15:35:34 +0800 Subject: [PATCH 4/8] *: rename to something understandable Signed-off-by: xhe --- pkg/proxy/backend/backend_conn_mgr_test.go | 2 +- pkg/proxy/backend/common_test.go | 6 +++--- pkg/proxy/backend/testsuite_test.go | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index 9293fa11..975a8f65 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -94,7 +94,7 @@ type backendMgrTester struct { func newBackendMgrTester(t *testing.T, cfg ...cfgOverrider) *backendMgrTester { tc := newTCPConnSuite(t) cfg = append(cfg, func(cfg *testConfig) { - cfg.testSuiteConfig.initBackendConn = false + cfg.testSuiteConfig.enableRouteLogic = true }) ts, clean := newTestSuite(t, tc, cfg...) tester := &backendMgrTester{ diff --git a/pkg/proxy/backend/common_test.go b/pkg/proxy/backend/common_test.go index 026ff719..17c69e31 100644 --- a/pkg/proxy/backend/common_test.go +++ b/pkg/proxy/backend/common_test.go @@ -55,9 +55,9 @@ func newTCPConnSuite(t *testing.T) *tcpConnSuite { return r } -func (tc *tcpConnSuite) newConn(t *testing.T, withBackend bool) func() { +func (tc *tcpConnSuite) newConn(t *testing.T, enableRoute bool) func() { var wg waitgroup.WaitGroup - if withBackend { + if !enableRoute { wg.Run(func() { conn, err := tc.backendListener.Accept() require.NoError(t, err) @@ -65,7 +65,7 @@ func (tc *tcpConnSuite) newConn(t *testing.T, withBackend bool) func() { }) } wg.Run(func() { - if withBackend { + if !enableRoute { backendConn, err := net.Dial("tcp", tc.backendListener.Addr().String()) require.NoError(t, err) tc.proxyBIO = pnet.NewPacketIO(backendConn) diff --git a/pkg/proxy/backend/testsuite_test.go b/pkg/proxy/backend/testsuite_test.go index 32384da3..656907ad 100644 --- a/pkg/proxy/backend/testsuite_test.go +++ b/pkg/proxy/backend/testsuite_test.go @@ -110,13 +110,13 @@ type testSuite struct { } type testSuiteConfig struct { - initBackendConn bool + // When true, routing logic in handshakeFirstTime is enabled. + // When false, a manual created backendIO is passed to handler to skip the routing logic. + enableRouteLogic bool } func newTestSuiteConfig() *testSuiteConfig { - return &testSuiteConfig{ - initBackendConn: true, - } + return &testSuiteConfig{} } type checker func(t *testing.T, ts *testSuite) @@ -137,7 +137,7 @@ func newTestSuite(t *testing.T, tc *tcpConnSuite, overriders ...cfgOverrider) (* ts.mc = newMockClient(cfg.clientConfig) ts.tc = tc ts.testSuiteConfig = cfg.testSuiteConfig - clean := tc.newConn(t, ts.initBackendConn) + clean := tc.newConn(t, ts.enableRouteLogic) return ts, clean } From fef6b0f346f2ad516591612ee850910ec1f3f40d Mon Sep 17 00:00:00 2001 From: xhe Date: Mon, 26 Dec 2022 18:16:17 +0800 Subject: [PATCH 5/8] *: split it into a standalone function Signed-off-by: xhe --- pkg/proxy/backend/authenticator.go | 39 +++------------------------ pkg/proxy/backend/backend_conn_mgr.go | 39 ++++++++++++++++++++++++++- pkg/proxy/backend/mock_proxy_test.go | 2 +- 3 files changed, 42 insertions(+), 38 deletions(-) diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index 56ad57af..28653eea 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -15,17 +15,13 @@ package backend import ( - "context" "crypto/tls" "encoding/binary" "fmt" "net" "sync" - "time" - "github.com/cenkalti/backoff/v4" "github.com/pingcap/TiProxy/lib/util/errors" - "github.com/pingcap/TiProxy/pkg/manager/router" pnet "github.com/pingcap/TiProxy/pkg/proxy/net" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/util/hack" @@ -105,7 +101,7 @@ func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapabili return nil } -func (mgr *BackendConnManager) handshakeFirstTime(logger *zap.Logger, clientIO, backendIO *pnet.PacketIO, handshakeHandler HandshakeHandler, frontendTLSConfig, backendTLSConfig *tls.Config) error { +func (mgr *BackendConnManager) handshakeFirstTime(logger *zap.Logger, clientIO, backendIO *pnet.PacketIO, frontendTLSConfig, backendTLSConfig *tls.Config) error { auth := mgr.authenticator clientIO.ResetSequence() @@ -154,7 +150,7 @@ func (mgr *BackendConnManager) handshakeFirstTime(logger *zap.Logger, clientIO, auth.capability = commonCaps.Uint32() resp := pnet.ParseHandshakeResponse(pkt) - if err = handshakeHandler.HandleHandshakeResp(auth, resp); err != nil { + if err = mgr.handshakeHandler.HandleHandshakeResp(auth, resp); err != nil { return err } auth.user = resp.User @@ -164,39 +160,10 @@ func (mgr *BackendConnManager) handshakeFirstTime(logger *zap.Logger, clientIO, // In case of testing, backendIO is passed manually that we don't want to bother with the routing logic. if backendIO == nil { - r, err := handshakeHandler.GetRouter(auth, resp) + backendIO, err = mgr.getBackendIO(resp) if err != nil { return err } - // wait for initialize - bctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - addr, err := backoff.RetryNotifyWithData( - func() (string, error) { - addr, err := r.Route(mgr) - if !errors.Is(err, router.ErrNoInstanceToSelect) { - return addr, backoff.Permanent(err) - } - return addr, err - }, - backoff.WithContext(backoff.NewConstantBackOff(200*time.Millisecond), bctx), - func(err error, d time.Duration) { - mgr.handshakeHandler.OnHandshake(auth, "", err) - }, - ) - cancel() - if err != nil { - return err - } - - mgr.logger.Info("found", zap.String("addr", addr)) - mgr.backendConn = NewBackendConnection(addr) - if err := mgr.backendConn.Connect(); err != nil { - mgr.handshakeHandler.OnHandshake(auth, addr, err) - return err - } - - auth.serverAddr = addr - backendIO = mgr.backendConn.PacketIO() } backendIO.ResetSequence() diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 1e42d76e..91b9ee2e 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -26,6 +26,7 @@ import ( "time" "unsafe" + "github.com/cenkalti/backoff/v4" gomysql "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/TiProxy/lib/util/errors" "github.com/pingcap/TiProxy/lib/util/waitgroup" @@ -120,7 +121,7 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe mgr.processLock.Lock() defer mgr.processLock.Unlock() - err := mgr.handshakeFirstTime(mgr.logger.Named("authenticator"), clientIO, nil, mgr.handshakeHandler, frontendTLSConfig, backendTLSConfig) + err := mgr.handshakeFirstTime(mgr.logger.Named("authenticator"), clientIO, nil, frontendTLSConfig, backendTLSConfig) mgr.handshakeHandler.OnHandshake(mgr.authenticator, mgr.authenticator.serverAddr, err) if err != nil { return err @@ -135,6 +136,42 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe return nil } +func (mgr *BackendConnManager) getBackendIO(resp *pnet.HandshakeResp) (*pnet.PacketIO, error) { + r, err := mgr.handshakeHandler.GetRouter(mgr.authenticator, resp) + if err != nil { + return nil, err + } + // wait for initialize + bctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + addr, err := backoff.RetryNotifyWithData( + func() (string, error) { + addr, err := r.Route(mgr) + if !errors.Is(err, router.ErrNoInstanceToSelect) { + return addr, backoff.Permanent(err) + } + return addr, err + }, + backoff.WithContext(backoff.NewConstantBackOff(200*time.Millisecond), bctx), + func(err error, d time.Duration) { + mgr.handshakeHandler.OnHandshake(mgr.authenticator, "", err) + }, + ) + cancel() + if err != nil { + return nil, err + } + + mgr.logger.Info("found", zap.String("addr", addr)) + mgr.backendConn = NewBackendConnection(addr) + if err := mgr.backendConn.Connect(); err != nil { + mgr.handshakeHandler.OnHandshake(mgr.authenticator, addr, err) + return nil, err + } + + mgr.authenticator.serverAddr = addr + return mgr.backendConn.PacketIO(), nil +} + // ExecuteCmd forwards messages between the client and the backend. // If it finds that the session is ready for redirection, it migrates the session. func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte, clientIO *pnet.PacketIO) error { diff --git a/pkg/proxy/backend/mock_proxy_test.go b/pkg/proxy/backend/mock_proxy_test.go index 64cbd1bd..f602ea23 100644 --- a/pkg/proxy/backend/mock_proxy_test.go +++ b/pkg/proxy/backend/mock_proxy_test.go @@ -63,7 +63,7 @@ func newMockProxy(t *testing.T, cfg *proxyConfig) *mockProxy { } func (mp *mockProxy) authenticateFirstTime(clientIO, backendIO *pnet.PacketIO) error { - if err := mp.handshakeFirstTime(mp.logger, clientIO, backendIO, mp.handshakeHandler, mp.frontendTLSConfig, mp.backendTLSConfig); err != nil { + if err := mp.handshakeFirstTime(mp.logger, clientIO, backendIO, mp.frontendTLSConfig, mp.backendTLSConfig); err != nil { return err } mp.cmdProcessor.capability = mp.authenticator.capability From 6666f511a7823d03cb4d7ca32bf678736d25d0fa Mon Sep 17 00:00:00 2001 From: xhe Date: Tue, 27 Dec 2022 17:02:21 +0800 Subject: [PATCH 6/8] *: handshakeFirstTime as a method of Authenticator Signed-off-by: xhe --- pkg/proxy/backend/authenticator.go | 10 ++++++---- pkg/proxy/backend/backend_conn_mgr.go | 12 ++++++------ pkg/proxy/backend/mock_proxy_test.go | 2 +- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index 28653eea..923fef1d 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -101,10 +101,12 @@ func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapabili return nil } -func (mgr *BackendConnManager) handshakeFirstTime(logger *zap.Logger, clientIO, backendIO *pnet.PacketIO, frontendTLSConfig, backendTLSConfig *tls.Config) error { - auth := mgr.authenticator +type backendIOGetter func(ctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) +func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO, backendIO *pnet.PacketIO, handshakeHandler HandshakeHandler, + getBackendIO backendIOGetter, frontendTLSConfig, backendTLSConfig *tls.Config) error { clientIO.ResetSequence() + auth.serverAddr = backendIO.SourceAddr().String() auth.clientAddr = clientIO.SourceAddr().String() proxyCapability := auth.supportedServerCapabilities @@ -150,7 +152,7 @@ func (mgr *BackendConnManager) handshakeFirstTime(logger *zap.Logger, clientIO, auth.capability = commonCaps.Uint32() resp := pnet.ParseHandshakeResponse(pkt) - if err = mgr.handshakeHandler.HandleHandshakeResp(auth, resp); err != nil { + if err = handshakeHandler.HandleHandshakeResp(auth, resp); err != nil { return err } auth.user = resp.User @@ -160,7 +162,7 @@ func (mgr *BackendConnManager) handshakeFirstTime(logger *zap.Logger, clientIO, // In case of testing, backendIO is passed manually that we don't want to bother with the routing logic. if backendIO == nil { - backendIO, err = mgr.getBackendIO(resp) + backendIO, err = getBackendIO(auth, auth, resp) if err != nil { return err } diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 91b9ee2e..1efebd25 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -121,7 +121,7 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe mgr.processLock.Lock() defer mgr.processLock.Unlock() - err := mgr.handshakeFirstTime(mgr.logger.Named("authenticator"), clientIO, nil, frontendTLSConfig, backendTLSConfig) + err := mgr.authenticator.handshakeFirstTime(mgr.logger.Named("authenticator"), clientIO, nil, mgr.handshakeHandler, mgr.getBackendIO, frontendTLSConfig, backendTLSConfig) mgr.handshakeHandler.OnHandshake(mgr.authenticator, mgr.authenticator.serverAddr, err) if err != nil { return err @@ -136,8 +136,8 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe return nil } -func (mgr *BackendConnManager) getBackendIO(resp *pnet.HandshakeResp) (*pnet.PacketIO, error) { - r, err := mgr.handshakeHandler.GetRouter(mgr.authenticator, resp) +func (mgr *BackendConnManager) getBackendIO(ctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) { + r, err := mgr.handshakeHandler.GetRouter(auth, resp) if err != nil { return nil, err } @@ -153,7 +153,7 @@ func (mgr *BackendConnManager) getBackendIO(resp *pnet.HandshakeResp) (*pnet.Pac }, backoff.WithContext(backoff.NewConstantBackOff(200*time.Millisecond), bctx), func(err error, d time.Duration) { - mgr.handshakeHandler.OnHandshake(mgr.authenticator, "", err) + mgr.handshakeHandler.OnHandshake(auth, "", err) }, ) cancel() @@ -164,11 +164,11 @@ func (mgr *BackendConnManager) getBackendIO(resp *pnet.HandshakeResp) (*pnet.Pac mgr.logger.Info("found", zap.String("addr", addr)) mgr.backendConn = NewBackendConnection(addr) if err := mgr.backendConn.Connect(); err != nil { - mgr.handshakeHandler.OnHandshake(mgr.authenticator, addr, err) + mgr.handshakeHandler.OnHandshake(auth, addr, err) return nil, err } - mgr.authenticator.serverAddr = addr + auth.serverAddr = addr return mgr.backendConn.PacketIO(), nil } diff --git a/pkg/proxy/backend/mock_proxy_test.go b/pkg/proxy/backend/mock_proxy_test.go index f602ea23..07c26ea1 100644 --- a/pkg/proxy/backend/mock_proxy_test.go +++ b/pkg/proxy/backend/mock_proxy_test.go @@ -63,7 +63,7 @@ func newMockProxy(t *testing.T, cfg *proxyConfig) *mockProxy { } func (mp *mockProxy) authenticateFirstTime(clientIO, backendIO *pnet.PacketIO) error { - if err := mp.handshakeFirstTime(mp.logger, clientIO, backendIO, mp.frontendTLSConfig, mp.backendTLSConfig); err != nil { + if err := mp.authenticator.handshakeFirstTime(mp.logger, clientIO, backendIO, mp.handshakeHandler, nil, mp.frontendTLSConfig, mp.backendTLSConfig); err != nil { return err } mp.cmdProcessor.capability = mp.authenticator.capability From 6fc923f69f77221023e888738168062abef47ec8 Mon Sep 17 00:00:00 2001 From: xhe Date: Tue, 27 Dec 2022 17:06:21 +0800 Subject: [PATCH 7/8] *: fix merge error Signed-off-by: xhe --- pkg/proxy/backend/authenticator.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index 923fef1d..0b5c3ba6 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -106,7 +106,6 @@ type backendIOGetter func(ctx ConnContext, auth *Authenticator, resp *pnet.Hands func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO, backendIO *pnet.PacketIO, handshakeHandler HandshakeHandler, getBackendIO backendIOGetter, frontendTLSConfig, backendTLSConfig *tls.Config) error { clientIO.ResetSequence() - auth.serverAddr = backendIO.SourceAddr().String() auth.clientAddr = clientIO.SourceAddr().String() proxyCapability := auth.supportedServerCapabilities From d37b83c58d0983f53e419edf76dd372f3ea8f4f7 Mon Sep 17 00:00:00 2001 From: xhe Date: Tue, 27 Dec 2022 17:09:35 +0800 Subject: [PATCH 8/8] *: more original Signed-off-by: xhe --- pkg/proxy/backend/authenticator.go | 10 ++++------ pkg/proxy/backend/backend_conn_mgr.go | 2 +- pkg/proxy/backend/mock_proxy_test.go | 4 +++- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index 0b5c3ba6..ab00a0a9 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -103,7 +103,7 @@ func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapabili type backendIOGetter func(ctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) -func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO, backendIO *pnet.PacketIO, handshakeHandler HandshakeHandler, +func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet.PacketIO, handshakeHandler HandshakeHandler, getBackendIO backendIOGetter, frontendTLSConfig, backendTLSConfig *tls.Config) error { clientIO.ResetSequence() auth.clientAddr = clientIO.SourceAddr().String() @@ -160,11 +160,9 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO, back auth.attrs = resp.Attrs // In case of testing, backendIO is passed manually that we don't want to bother with the routing logic. - if backendIO == nil { - backendIO, err = getBackendIO(auth, auth, resp) - if err != nil { - return err - } + backendIO, err := getBackendIO(auth, auth, resp) + if err != nil { + return err } backendIO.ResetSequence() diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index e18b1d3f..3f6f9a82 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -139,7 +139,7 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe mgr.processLock.Lock() defer mgr.processLock.Unlock() - err := mgr.authenticator.handshakeFirstTime(mgr.logger.Named("authenticator"), clientIO, nil, mgr.handshakeHandler, mgr.getBackendIO, frontendTLSConfig, backendTLSConfig) + err := mgr.authenticator.handshakeFirstTime(mgr.logger.Named("authenticator"), clientIO, mgr.handshakeHandler, mgr.getBackendIO, frontendTLSConfig, backendTLSConfig) mgr.handshakeHandler.OnHandshake(mgr.authenticator, mgr.authenticator.serverAddr, err) if err != nil { return err diff --git a/pkg/proxy/backend/mock_proxy_test.go b/pkg/proxy/backend/mock_proxy_test.go index 07c26ea1..2ebb9625 100644 --- a/pkg/proxy/backend/mock_proxy_test.go +++ b/pkg/proxy/backend/mock_proxy_test.go @@ -63,7 +63,9 @@ func newMockProxy(t *testing.T, cfg *proxyConfig) *mockProxy { } func (mp *mockProxy) authenticateFirstTime(clientIO, backendIO *pnet.PacketIO) error { - if err := mp.authenticator.handshakeFirstTime(mp.logger, clientIO, backendIO, mp.handshakeHandler, nil, mp.frontendTLSConfig, mp.backendTLSConfig); err != nil { + if err := mp.authenticator.handshakeFirstTime(mp.logger, clientIO, mp.handshakeHandler, func(ConnContext, *Authenticator, *pnet.HandshakeResp) (*pnet.PacketIO, error) { + return backendIO, nil + }, mp.frontendTLSConfig, mp.backendTLSConfig); err != nil { return err } mp.cmdProcessor.capability = mp.authenticator.capability