Skip to content

Commit

Permalink
backend: add quit source to ConnContext (#236)
Browse files Browse the repository at this point in the history
  • Loading branch information
djshow832 authored Mar 6, 2023
1 parent cfa03e5 commit 14429a3
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 30 deletions.
71 changes: 58 additions & 13 deletions pkg/proxy/backend/backend_conn_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ type BackendConnManager struct {
handshakeHandler HandshakeHandler
ctxmap sync.Map
connectionID uint64
quitSource ErrorSource
}

// NewBackendConnManager creates a BackendConnManager.
Expand All @@ -151,6 +152,7 @@ func NewBackendConnManager(logger *zap.Logger, handshakeHandler HandshakeHandler
// There are 2 types of signals, which may be sent concurrently.
signalReceived: make(chan signalType, signalTypeNums),
redirectResCh: make(chan *redirectResult, 1),
quitSource: SrcClientQuit,
}
return mgr
}
Expand All @@ -170,11 +172,14 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe

mgr.clientIO = clientIO
err := mgr.authenticator.handshakeFirstTime(mgr.logger.Named("authenticator"), mgr, clientIO, mgr.handshakeHandler, mgr.getBackendIO, frontendTLSConfig, backendTLSConfig)
mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), err)
if err != nil {
mgr.setQuitSourceByErr(err)
mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), err)
WriteUserError(clientIO, err, mgr.logger)
return err
}
mgr.resetQuitSource()
mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), nil)

mgr.cmdProcessor.capability = mgr.authenticator.capability
childCtx, cancelFunc := context.WithCancel(ctx)
Expand Down Expand Up @@ -233,14 +238,15 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato
// NOTE: should use DNS name as much as possible
// Usually certs are signed with domain instead of IP addrs
// And `RemoteAddr()` will return IP addr
backendIO := pnet.NewPacketIO(cn, pnet.WithRemoteAddr(addr, cn.RemoteAddr()))
backendIO := pnet.NewPacketIO(cn, pnet.WithRemoteAddr(addr, cn.RemoteAddr()), pnet.WithWrapError(ErrBackendConn))
mgr.backendIO.Store(backendIO)
mgr.setKeepAlive(mgr.config.HealthyKeepAlive)
return backendIO, nil
},
backoff.WithContext(backoff.NewConstantBackOff(200*time.Millisecond), bctx),
func(err error, d time.Duration) {
origErr = err
mgr.setQuitSourceByErr(err)
mgr.handshakeHandler.OnHandshake(cctx, addr, err)
},
)
Expand All @@ -264,9 +270,13 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato

// ExecuteCmd forwards messages between the client and the backend.
// If it finds that the session is ready for redirection, it migrates the session.
func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) error {
func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) (err error) {
defer func() {
mgr.setQuitSourceByErr(err)
}()
if len(request) < 1 {
return mysql.ErrMalformPacket
err = mysql.ErrMalformPacket
return
}
cmd := request[0]
startTime := time.Now()
Expand All @@ -275,25 +285,26 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) e

switch mgr.closeStatus.Load() {
case statusClosing, statusClosed:
return nil
return
}
defer mgr.resetCheckBackendTicker()
waitingRedirect := atomic.LoadPointer(&mgr.signal) != nil
holdRequest, err := mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO.Load(), waitingRedirect)
var holdRequest bool
holdRequest, err = mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO.Load(), waitingRedirect)
if !holdRequest {
addCmdMetrics(cmd, mgr.ServerAddr(), startTime)
}
if err != nil {
if !IsMySQLError(err) {
return err
return
} else {
mgr.logger.Debug("got a mysql error", zap.Error(err))
}
}
if err == nil {
switch cmd {
case mysql.ComQuit:
return nil
return
case mysql.ComSetOption:
val := binary.LittleEndian.Uint16(request[1:])
switch val {
Expand All @@ -304,12 +315,13 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) e
mgr.authenticator.capability &^= mysql.ClientMultiStatements
mgr.cmdProcessor.capability &^= mysql.ClientMultiStatements
default:
return errors.Errorf("unrecognized set_option value:%d", val)
err = errors.Errorf("unrecognized set_option value:%d", val)
return
}
case mysql.ComChangeUser:
username, db := pnet.ParseChangeUser(request)
mgr.authenticator.changeUser(username, db)
return nil
return
}
}
// Even if it meets an MySQL error, it may have changed the status, such as when executing multi-statements.
Expand All @@ -320,7 +332,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) e
_, err = mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO.Load(), false)
addCmdMetrics(cmd, mgr.ServerAddr(), startTime)
if err != nil && !IsMySQLError(err) {
return err
return
}
} else if mgr.closeStatus.Load() == statusNotifyClose {
mgr.tryGracefulClose(ctx)
Expand All @@ -329,7 +341,8 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) e
}
}
// Ignore MySQL errors, only return unexpected errors.
return nil
err = nil
return
}

// SetEventReceiver implements RedirectableConn.SetEventReceiver interface.
Expand Down Expand Up @@ -428,6 +441,7 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) {
// If the backend connection is closed, also close the client connection.
// Otherwise, if the client is idle, the mgr will keep retrying.
if errors.Is(rs.err, net.ErrClosed) || pnet.IsDisconnectError(rs.err) || errors.Is(rs.err, os.ErrDeadlineExceeded) {
mgr.quitSource = SrcBackendQuit
if ignoredErr := mgr.clientIO.GracefulClose(); ignoredErr != nil {
mgr.logger.Warn("graceful close client IO error", zap.Stringer("addr", mgr.clientIO.RemoteAddr()), zap.Error(ignoredErr))
}
Expand All @@ -438,17 +452,20 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) {
return
}

defer mgr.resetQuitSource()
var cn net.Conn
cn, rs.err = net.DialTimeout("tcp", rs.to, DialTimeout)
if rs.err != nil {
mgr.quitSource = SrcBackendQuit
mgr.handshakeHandler.OnHandshake(mgr, rs.to, rs.err)
return
}
newBackendIO := pnet.NewPacketIO(cn, pnet.WithRemoteAddr(rs.to, cn.RemoteAddr()))
newBackendIO := pnet.NewPacketIO(cn, pnet.WithRemoteAddr(rs.to, cn.RemoteAddr()), pnet.WithWrapError(ErrBackendConn))

if rs.err = mgr.authenticator.handshakeSecondTime(mgr.logger, mgr.clientIO, newBackendIO, mgr.backendTLS, sessionToken); rs.err == nil {
rs.err = mgr.initSessionStates(newBackendIO, sessionStates)
} else {
mgr.setQuitSourceByErr(rs.err)
mgr.handshakeHandler.OnHandshake(mgr, newBackendIO.RemoteAddr().String(), rs.err)
}
if rs.err != nil {
Expand Down Expand Up @@ -538,6 +555,7 @@ func (mgr *BackendConnManager) tryGracefulClose(ctx context.Context) {
if !mgr.cmdProcessor.finishedTxn() {
return
}
mgr.quitSource = SrcProxyQuit
// Closing clientIO will cause the whole connection to be closed.
if err := mgr.clientIO.GracefulClose(); err != nil {
mgr.logger.Warn("graceful close client IO error", zap.Stringer("addr", mgr.clientIO.RemoteAddr()), zap.Error(err))
Expand All @@ -557,6 +575,7 @@ func (mgr *BackendConnManager) checkBackendActive() {
if !backendIO.IsPeerActive() {
mgr.logger.Info("backend connection is closed, close client connection", zap.Stringer("client", mgr.clientIO.RemoteAddr()),
zap.Stringer("backend", backendIO.RemoteAddr()))
mgr.quitSource = SrcBackendQuit
if err := mgr.clientIO.GracefulClose(); err != nil {
mgr.logger.Warn("graceful close client IO error", zap.Stringer("addr", mgr.clientIO.RemoteAddr()), zap.Error(err))
}
Expand Down Expand Up @@ -602,6 +621,10 @@ func (mgr *BackendConnManager) ClientOutBytes() uint64 {
return mgr.clientIO.OutBytes()
}

func (mgr *BackendConnManager) QuitSource() ErrorSource {
return mgr.quitSource
}

func (mgr *BackendConnManager) SetValue(key, val any) {
mgr.ctxmap.Store(key, val)
}
Expand Down Expand Up @@ -675,3 +698,25 @@ func (mgr *BackendConnManager) setKeepAlive(cfg config.KeepAlive) {
mgr.logger.Warn("failed to set keepalive", zap.Error(err), zap.Stringer("backend", backendIO.RemoteAddr()))
}
}

// quitSource will be read by OnHandshake and OnConnClose, so setQuitSourceByErr should be called before them.
func (mgr *BackendConnManager) setQuitSourceByErr(err error) {
// Do not update the source if err is nil. It may be already be set.
if err == nil {
return
}
if errors.Is(err, ErrBackendConn) {
mgr.quitSource = SrcBackendQuit
} else if IsMySQLError(err) {
mgr.quitSource = SrcClientErr
} else if !errors.Is(err, ErrClientConn) {
mgr.quitSource = SrcProxyErr
}
}

func (mgr *BackendConnManager) resetQuitSource() {
// SrcClientQuit is by default.
// Sometimes ErrClientConn is caused by GracefulClose and the quitSource is already set.
// Error maybe set during handshake for OnHandshake. If handshake finally succeeds, we reset it.
mgr.quitSource = SrcClientQuit
}
57 changes: 52 additions & 5 deletions pkg/proxy/backend/backend_conn_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ func TestNormalRedirect(t *testing.T) {
ts.mp.Redirect(ts.tc.backendListener.Addr().String())
ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventSucceed)
require.NotEqual(t, backend1, ts.mp.backendIO.Load())
require.Equal(t, SrcClientQuit, ts.mp.QuitSource())
return nil
},
backend: ts.redirectSucceed4Backend,
Expand Down Expand Up @@ -352,6 +353,7 @@ func TestRedirectInTxn(t *testing.T) {
require.NoError(t, err)
ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventFail)
require.Equal(t, backend1, ts.mp.backendIO.Load())
require.Equal(t, SrcClientQuit, ts.mp.QuitSource())
return nil
},
backend: func(packetIO *pnet.PacketIO) error {
Expand Down Expand Up @@ -388,6 +390,12 @@ func TestConnectFail(t *testing.T) {
return ts.mb.authenticate(ts.tc.backendIO)
},
},
{
proxy: func(clientIO, backendIO *pnet.PacketIO) error {
require.Equal(t, SrcClientErr, ts.mp.QuitSource())
return nil
},
},
}
ts.runTests(runners)
}
Expand Down Expand Up @@ -499,6 +507,7 @@ func TestSpecialCmds(t *testing.T) {
ts.mp.Redirect(ts.tc.backendListener.Addr().String())
ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventSucceed)
require.NotEqual(t, backend1, ts.mp.backendIO.Load())
require.Equal(t, SrcClientQuit, ts.mp.QuitSource())
return nil
},
backend: func(packetIO *pnet.PacketIO) error {
Expand Down Expand Up @@ -599,6 +608,12 @@ func TestCustomHandshake(t *testing.T) {
},
backend: ts.redirectSucceed4Backend,
},
{
proxy: func(clientIO, backendIO *pnet.PacketIO) error {
require.Equal(t, SrcClientQuit, ts.mp.QuitSource())
return nil
},
},
}
ts.runTests(runners)
}
Expand All @@ -623,6 +638,12 @@ func TestGracefulCloseWhenIdle(t *testing.T) {
{
proxy: ts.checkConnClosed4Proxy,
},
{
proxy: func(clientIO, backendIO *pnet.PacketIO) error {
require.Equal(t, SrcProxyQuit, ts.mp.QuitSource())
return nil
},
},
}
ts.runTests(runners)
}
Expand Down Expand Up @@ -661,6 +682,12 @@ func TestGracefulCloseWhenActive(t *testing.T) {
{
proxy: ts.checkConnClosed4Proxy,
},
{
proxy: func(clientIO, backendIO *pnet.PacketIO) error {
require.Equal(t, SrcProxyQuit, ts.mp.QuitSource())
return nil
},
},
}
ts.runTests(runners)
}
Expand All @@ -685,30 +712,39 @@ func TestGracefulCloseBeforeHandshake(t *testing.T) {
{
proxy: ts.checkConnClosed4Proxy,
},
{
proxy: func(clientIO, backendIO *pnet.PacketIO) error {
require.Equal(t, SrcProxyQuit, ts.mp.QuitSource())
return nil
},
},
}
ts.runTests(runners)
}

func TestHandlerReturnError(t *testing.T) {
tests := []struct {
cfg cfgOverrider
errMsg string
cfg cfgOverrider
errMsg string
quitSource ErrorSource
}{
{
cfg: func(config *testConfig) {
config.proxyConfig.handler.getRouter = func(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) {
return nil, errors.New("mocked error")
}
},
errMsg: "mocked error",
errMsg: "mocked error",
quitSource: SrcProxyErr,
},
{
cfg: func(config *testConfig) {
config.proxyConfig.handler.handleHandshakeResp = func(ctx ConnContext, resp *pnet.HandshakeResp) error {
return errors.New("mocked error")
}
},
errMsg: "mocked error",
errMsg: "mocked error",
quitSource: SrcProxyErr,
},
{
// TODO: make it fail faster.
Expand All @@ -717,7 +753,8 @@ func TestHandlerReturnError(t *testing.T) {
return router.NewStaticRouter(nil), nil
}
},
errMsg: connectErrMsg,
errMsg: connectErrMsg,
quitSource: SrcProxyErr,
},
}
for _, test := range tests {
Expand All @@ -732,6 +769,7 @@ func TestHandlerReturnError(t *testing.T) {
proxy: func(clientIO, backendIO *pnet.PacketIO) error {
err := ts.mp.Connect(context.Background(), clientIO, ts.mp.frontendTLSConfig, ts.mp.backendTLSConfig)
require.Error(t, err)
require.Equal(t, test.quitSource, ts.mp.QuitSource())
return nil
},
backend: nil,
Expand Down Expand Up @@ -761,6 +799,9 @@ func TestGetBackendIO(t *testing.T) {
if err != nil && len(s) > 0 {
badAddrs[s] = struct{}{}
}
if err != nil {
require.Equal(t, SrcProxyErr, connContext.QuitSource())
}
},
}
mgr := NewBackendConnManager(logger.CreateLoggerForTest(t), handler, 0, &BCConfig{})
Expand Down Expand Up @@ -865,6 +906,12 @@ func TestBackendInactive(t *testing.T) {
return packetIO.Close()
},
},
{
proxy: func(clientIO, backendIO *pnet.PacketIO) error {
require.Equal(t, SrcBackendQuit, ts.mp.QuitSource())
return nil
},
},
}
ts.runTests(runners)
}
Expand Down
5 changes: 5 additions & 0 deletions pkg/proxy/backend/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ const (
capabilityErrMsg = "Verify TiDB capability failed, please upgrade TiDB"
)

var (
ErrClientConn = errors.New("this is an error from client")
ErrBackendConn = errors.New("this is an error from backend")
)

// UserError is returned to the client.
// err is used to log and userMsg is used to report to the user.
type UserError struct {
Expand Down
Loading

0 comments on commit 14429a3

Please sign in to comment.