From c9c5817c9049e4a0f1f6af21911f03a67b15abd2 Mon Sep 17 00:00:00 2001 From: disksing Date: Mon, 30 Oct 2023 15:21:35 +0800 Subject: [PATCH] auth: reconnect backend (#389) Signed-off-by: disksing --- pkg/proxy/backend/authenticator.go | 24 +++++++++++++++++++++--- pkg/proxy/backend/handshake_handler.go | 14 ++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index 2889bdce..00d482e4 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -12,6 +12,7 @@ import ( "time" "github.com/go-mysql-org/go-mysql/mysql" + gomysql "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tiproxy/lib/util/errors" pnet "github.com/pingcap/tiproxy/pkg/proxy/net" @@ -159,6 +160,8 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte auth.attrs = clientResp.Attrs auth.zstdLevel = clientResp.ZstdLevel +RECONNECT: + // In case of testing, backendIO is passed manually that we don't want to bother with the routing logic. backendIO, err := getBackendIO(cctx, auth, clientResp, 15*time.Second) if err != nil { @@ -214,7 +217,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte pktIdx := 0 loop: for { - serverPkt, err := forwardMsg(backendIO, clientIO) + serverPkt, err := backendIO.ReadPacket() if err != nil { // tiproxy pp enabled, tidb pp disabled, tls disabled => invalid sequence // tiproxy pp disabled, tidb pp enabled, tls disabled => invalid sequence @@ -223,6 +226,23 @@ loop: } return err } + var packetErr error + if serverPkt[0] == pnet.ErrHeader.Byte() { + packetErr = pnet.ParseErrorPacket(serverPkt) + if handshakeHandler.HandleHandshakeErr(cctx, packetErr.(*gomysql.MyError)) { + logger.Warn("handle handshake error, start reconnect", zap.Error(err)) + backendIO.Close() + goto RECONNECT + } + } + err = clientIO.WritePacket(serverPkt, true) + if err != nil { + return err + } + if packetErr != nil { + return packetErr + } + pktIdx++ switch serverPkt[0] { case pnet.OKHeader.Byte(): @@ -233,8 +253,6 @@ loop: return err } return nil - case pnet.ErrHeader.Byte(): - return pnet.ParseErrorPacket(serverPkt) default: // mysql.AuthSwitchRequest, ShaCommand if serverPkt[0] == pnet.AuthSwitchHeader.Byte() { pluginName = string(serverPkt[1 : bytes.IndexByte(serverPkt[1:], 0)+1]) diff --git a/pkg/proxy/backend/handshake_handler.go b/pkg/proxy/backend/handshake_handler.go index 062600b2..dbe42cce 100644 --- a/pkg/proxy/backend/handshake_handler.go +++ b/pkg/proxy/backend/handshake_handler.go @@ -4,6 +4,7 @@ package backend import ( + gomysql "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/tiproxy/lib/util/errors" "github.com/pingcap/tiproxy/pkg/manager/namespace" "github.com/pingcap/tiproxy/pkg/manager/router" @@ -70,6 +71,7 @@ type ConnContext interface { type HandshakeHandler interface { HandleHandshakeResp(ctx ConnContext, resp *pnet.HandshakeResp) error + HandleHandshakeErr(ctx ConnContext, err *gomysql.MyError) bool // return true means retry connect GetRouter(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) OnHandshake(ctx ConnContext, to string, err error) OnConnClose(ctx ConnContext) error @@ -94,6 +96,10 @@ func (handler *DefaultHandshakeHandler) HandleHandshakeResp(ConnContext, *pnet.H return nil } +func (handler *DefaultHandshakeHandler) HandleHandshakeErr(ctx ConnContext, err *gomysql.MyError) bool { + return false +} + func (handler *DefaultHandshakeHandler) GetRouter(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) { ns, ok := handler.nsManager.GetNamespaceByUser(resp.User) if !ok { @@ -142,6 +148,7 @@ type CustomHandshakeHandler struct { onTraffic func(ConnContext) onConnClose func(ConnContext) error handleHandshakeResp func(ctx ConnContext, resp *pnet.HandshakeResp) error + handleHandshakeErr func(ctx ConnContext, err *gomysql.MyError) bool getCapability func() pnet.Capability getServerVersion func() string } @@ -179,6 +186,13 @@ func (h *CustomHandshakeHandler) HandleHandshakeResp(ctx ConnContext, resp *pnet return nil } +func (h *CustomHandshakeHandler) HandleHandshakeErr(ctx ConnContext, err *gomysql.MyError) bool { + if h.handleHandshakeErr != nil { + return h.handleHandshakeErr(ctx, err) + } + return false +} + func (h *CustomHandshakeHandler) GetCapability() pnet.Capability { if h.getCapability != nil { return h.getCapability()