diff --git a/server/BUILD.bazel b/server/BUILD.bazel
index 79deb82037ab3..e346e35c466b1 100644
--- a/server/BUILD.bazel
+++ b/server/BUILD.bazel
@@ -116,7 +116,6 @@ go_library(
         "@org_golang_google_grpc//channelz/service",
         "@org_golang_google_grpc//keepalive",
         "@org_golang_google_grpc//peer",
-        "@org_uber_go_atomic//:atomic",
         "@org_uber_go_zap//:zap",
     ],
 )
diff --git a/server/conn.go b/server/conn.go
index 41133a9bec358..ecd5977f0d101 100644
--- a/server/conn.go
+++ b/server/conn.go
@@ -172,7 +172,6 @@ func newClientConn(s *Server) *clientConn {
         status:     connStatusDispatching,
         lastActive: time.Now(),
         authPlugin: mysql.AuthNativePassword,
-        quit:       make(chan struct{}),
         ppEnabled:  s.cfg.ProxyProtocol.Networks != "",
     }
 }
@@ -216,8 +215,6 @@ type clientConn struct {
         sync.RWMutex
         cancelFunc context.CancelFunc
     }
-    // quit is close once clientConn quit Run().
-    quit chan struct{}
     extensions *extension.SessionExtensions
 
     // Proxy Protocol Enabled
@@ -1096,12 +1093,6 @@ func (cc *clientConn) Run(ctx context.Context) {
             terror.Log(err)
             metrics.PanicCounter.WithLabelValues(metrics.LabelSession).Inc()
         }
-        if atomic.LoadInt32(&cc.status) != connStatusShutdown {
-            err := cc.Close()
-            terror.Log(err)
-        }
-
-        close(cc.quit)
     }()
 
     // Usually, client connection status changes between [dispatching] <=> [reading].
     // When some event happens, server may notify this client connection by setting
     // the status to special values, for example: kill or graceful shutdown.
     // The client connection would detect the events when it fails to change status
     // by CAS operation, it would then take some actions accordingly.
     for {
-        // Close connection between txn when we are going to shutdown server.
-        if cc.server.inShutdownMode.Load() {
-            if !cc.ctx.GetSessionVars().InTxn() {
-                return
-            }
-        }
-
         if !atomic.CompareAndSwapInt32(&cc.status, connStatusDispatching, connStatusReading) ||
             // The judge below will not be hit by all means,
             // But keep it stayed as a reminder and for the code reference for connStatusWaitShutdown.
@@ -1126,7 +1110,6 @@
         cc.alloc.Reset()
         // close connection when idle time is more than wait_timeout
-        // default 28800(8h), FIXME: should not block at here when we kill the connection.
         waitTimeout := cc.getSessionVarsWaitTimeout(ctx)
         cc.pkt.setReadTimeout(time.Duration(waitTimeout) * time.Second)
         start := time.Now()
@@ -1213,6 +1196,22 @@
     }
 }
 
+// ShutdownOrNotify shuts down the client connection if it is safe to do so, or does its best to notify it to exit.
+func (cc *clientConn) ShutdownOrNotify() bool {
+    if (cc.ctx.Status() & mysql.ServerStatusInTrans) > 0 {
+        return false
+    }
+    // If the client connection status is reading, it's safe to shut it down.
+    if atomic.CompareAndSwapInt32(&cc.status, connStatusReading, connStatusShutdown) {
+        return true
+    }
+    // If the client connection status is dispatching, we can't shut it down immediately,
+    // so set the status to WaitShutdown as a notification; the loop in clientConn.Run
+    // will detect it and then exit.
+    atomic.StoreInt32(&cc.status, connStatusWaitShutdown)
+    return false
+}
+
 func errStrForLog(err error, enableRedactLog bool) string {
     if enableRedactLog {
         // currently, only ErrParse is considered when enableRedactLog because it may contain sensitive information like
diff --git a/server/conn_test.go b/server/conn_test.go
index c540b1784793d..fa3b9d5317a96 100644
--- a/server/conn_test.go
+++ b/server/conn_test.go
@@ -778,6 +778,31 @@ func TestShutDown(t *testing.T) {
     require.Equal(t, executor.ErrQueryInterrupted, err)
 }
 
+func TestShutdownOrNotify(t *testing.T) {
+    store := testkit.CreateMockStore(t)
+    se, err := session.CreateSession4Test(store)
+    require.NoError(t, err)
+    tc := &TiDBContext{
+        Session: se,
+        stmts:   make(map[int]*TiDBStatement),
+    }
+    cc := &clientConn{
+        connectionID: 1,
+        server: &Server{
+            capability: defaultCapability,
+        },
+        status: connStatusWaitShutdown,
+    }
+    cc.setCtx(tc)
+    require.False(t, cc.ShutdownOrNotify())
+    cc.status = connStatusReading
+    require.True(t, cc.ShutdownOrNotify())
+    require.Equal(t, connStatusShutdown, cc.status)
+    cc.status = connStatusDispatching
+    require.False(t, cc.ShutdownOrNotify())
+    require.Equal(t, connStatusWaitShutdown, cc.status)
+}
+
 type snapshotCache interface {
     SnapCacheHitCount() int
 }
diff --git a/server/http_status.go b/server/http_status.go
index 20aa534a7827f..8070bd91e2b99 100644
--- a/server/http_status.go
+++ b/server/http_status.go
@@ -539,7 +539,7 @@ func (s *Server) handleStatus(w http.ResponseWriter, req *http.Request) {
     // If the server is in the process of shutting down, return a non-200 status.
     // It is important not to return status{} as acquiring the s.ConnectionCount()
     // acquires a lock that may already be held by the shutdown process.
-    if !s.health.Load() {
+    if s.inShutdownMode {
         w.WriteHeader(http.StatusInternalServerError)
         return
     }
diff --git a/server/server.go b/server/server.go
index 8bb184261b046..3ab08629b232a 100644
--- a/server/server.go
+++ b/server/server.go
@@ -72,7 +72,6 @@ import (
     "github.com/pingcap/tidb/util/logutil"
     "github.com/pingcap/tidb/util/sys/linux"
     "github.com/pingcap/tidb/util/timeutil"
-    uatomic "go.uber.org/atomic"
     "go.uber.org/zap"
     "google.golang.org/grpc"
 )
@@ -130,21 +129,18 @@ type Server struct {
     driver            IDriver
     listener          net.Listener
     socket            net.Listener
+    rwlock            sync.RWMutex
     concurrentLimiter *TokenLimiter
-
-    rwlock  sync.RWMutex
-    clients map[uint64]*clientConn
-
-    capability   uint32
-    dom          *domain.Domain
-    globalConnID util.GlobalConnID
+    clients           map[uint64]*clientConn
+    capability        uint32
+    dom               *domain.Domain
+    globalConnID      util.GlobalConnID
 
     statusAddr     string
     statusListener net.Listener
     statusServer   *http.Server
     grpcServer     *grpc.Server
-    inShutdownMode *uatomic.Bool
-    health         *uatomic.Bool
+    inShutdownMode bool
 
     sessionMapMutex  sync.Mutex
     internalSessions map[interface{}]struct{}
@@ -213,8 +209,6 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
         globalConnID:     util.NewGlobalConnID(0, true),
         internalSessions: make(map[interface{}]struct{}, 100),
         printMDLLogTime:  time.Now(),
-        health:           uatomic.NewBool(true),
-        inShutdownMode:   uatomic.NewBool(false),
     }
     s.capability = defaultCapability
     setTxnScope()
@@ -402,7 +396,7 @@ func (s *Server) Run() error {
     }
     // If error should be reported and exit the server it can be sent on this
     // channel. Otherwise, end with sending a nil error to signal "done"
-    errChan := make(chan error, 2)
+    errChan := make(chan error)
     go s.startNetworkListener(s.listener, false, errChan)
     go s.startNetworkListener(s.socket, true, errChan)
     err := <-errChan
@@ -422,7 +416,7 @@ func (s *Server) startNetworkListener(listener net.Listener, isUnixSocket bool,
         if err != nil {
             if opErr, ok := err.(*net.OpError); ok {
                 if opErr.Err.Error() == "use of closed network connection" {
-                    if s.inShutdownMode.Load() {
+                    if s.inShutdownMode {
                         errChan <- nil
                     } else {
                         errChan <- err
@@ -442,8 +436,6 @@ func (s *Server) startNetworkListener(listener net.Listener, isUnixSocket bool,
             return
         }
 
-        logutil.BgLogger().Debug("accept new connection success")
-
         clientConn := s.newConn(conn)
         if isUnixSocket {
             var (
@@ -515,8 +507,10 @@ func (s *Server) checkAuditPlugin(clientConn *clientConn) error {
 }
 
 func (s *Server) startShutdown() {
+    s.rwlock.RLock()
     logutil.BgLogger().Info("setting tidb-server to report unhealthy (shutting-down)")
-    s.health.Store(false)
+    s.inShutdownMode = true
+    s.rwlock.RUnlock()
     // give the load balancer a chance to receive a few unhealthy health reports
     // before acquiring the s.rwlock and blocking connections.
     waitTime := time.Duration(s.cfg.GracefulWaitBeforeShutdown) * time.Second
@@ -526,7 +520,12 @@
     }
 }
 
-func (s *Server) closeListener() {
+// Close closes the server.
+func (s *Server) Close() {
+    s.startShutdown()
+    s.rwlock.Lock() // prevent new connections
+    defer s.rwlock.Unlock()
+
     if s.listener != nil {
         err := s.listener.Close()
         terror.Log(errors.Trace(err))
@@ -556,34 +555,6 @@
     metrics.ServerEventCounter.WithLabelValues(metrics.EventClose).Inc()
 }
 
-var gracefulCloseConnectionsTimeout = 15 * time.Second
-
-// Close closes the server.
-func (s *Server) Close() {
-    s.startShutdown()
-    s.rwlock.Lock() // // prevent new connections
-    defer s.rwlock.Unlock()
-    s.inShutdownMode.Store(true)
-    s.closeListener()
-}
-
-func (s *Server) registerConn(conn *clientConn) bool {
-    s.rwlock.Lock()
-    defer s.rwlock.Unlock()
-    connections := len(s.clients)
-
-    logger := logutil.BgLogger()
-    if s.inShutdownMode.Load() {
-        logger.Info("close connection directly when shutting down")
-        terror.Log(closeConn(conn, connections))
-        return false
-    }
-    s.clients[conn.connectionID] = conn
-    connections = len(s.clients)
-    metrics.ConnGauge.Set(float64(connections))
-    return true
-}
-
 // onConn runs in its own goroutine, handles queries from this connection.
 func (s *Server) onConn(conn *clientConn) {
     // init the connInfo
@@ -612,7 +583,6 @@
     }
 
     ctx := logutil.WithConnID(context.Background(), conn.connectionID)
-
     if err := conn.handshake(ctx); err != nil {
         conn.onExtensionConnEvent(extension.ConnHandshakeRejected, err)
         if plugin.IsEnable(plugin.Audit) && conn.getCtx() != nil {
@@ -654,10 +624,11 @@
         terror.Log(conn.Close())
         logutil.Logger(ctx).Debug("connection closed")
     }()
-
-    if !s.registerConn(conn) {
-        return
-    }
+    s.rwlock.Lock()
+    s.clients[conn.connectionID] = conn
+    connections := len(s.clients)
+    s.rwlock.Unlock()
+    metrics.ConnGauge.Set(float64(connections))
 
     sessionVars := conn.ctx.GetSessionVars()
     sessionVars.ConnectionInfo = conn.connectInfo()
@@ -813,7 +784,7 @@ func (s *Server) Kill(connectionID uint64, query bool) {
         // this, it will end the dispatch loop and exit.
         atomic.StoreInt32(&conn.status, connStatusWaitShutdown)
     }
-    killQuery(conn)
+    killConn(conn)
 }
 
 // UpdateTLSConfig implements the SessionManager interface.
@@ -825,7 +796,7 @@ func (s *Server) getTLSConfig() *tls.Config {
     return (*tls.Config)(atomic.LoadPointer(&s.tlsConfig))
 }
 
-func killQuery(conn *clientConn) {
+func killConn(conn *clientConn) {
     sessVars := conn.ctx.GetSessionVars()
     atomic.StoreUint32(&sessVars.Killed, 1)
     conn.mu.RLock()
@@ -853,8 +824,7 @@ func (s *Server) KillSysProcesses() {
     }
 }
 
-// KillAllConnections implements the SessionManager interface.
-// KillAllConnections kills all connections.
+// KillAllConnections kills all connections when the server is not shut down gracefully.
 func (s *Server) KillAllConnections() {
     logutil.BgLogger().Info("[server] kill all connections.")
 
@@ -865,53 +835,73 @@
         if err := conn.closeWithoutLock(); err != nil {
             terror.Log(err)
         }
-        killQuery(conn)
+        killConn(conn)
     }
 
     s.KillSysProcesses()
 }
 
-// DrainClients drain all connections in drainWait.
-// After drainWait duration, we kill all connections still not quit explicitly and wait for cancelWait.
-func (s *Server) DrainClients(drainWait time.Duration, cancelWait time.Duration) {
-    logger := logutil.BgLogger()
-    logger.Info("start drain clients")
-
-    conns := make(map[uint64]*clientConn)
-
-    s.rwlock.Lock()
-    for k, v := range s.clients {
-        conns[k] = v
-    }
-    s.rwlock.Unlock()
+var gracefulCloseConnectionsTimeout = 15 * time.Second
 
-    allDone := make(chan struct{})
-    quitWaitingForConns := make(chan struct{})
-    defer close(quitWaitingForConns)
+// TryGracefulDown first tries to close all connections gracefully within a timeout; if the timeout expires, it closes all connections directly.
+func (s *Server) TryGracefulDown() {
+    ctx, cancel := context.WithTimeout(context.Background(), gracefulCloseConnectionsTimeout)
+    defer cancel()
+    done := make(chan struct{})
     go func() {
-        defer close(allDone)
-        for _, conn := range conns {
-            select {
-            case <-conn.quit:
-            case <-quitWaitingForConns:
-                return
-            }
-        }
+        s.GracefulDown(ctx, done)
     }()
-
     select {
-    case <-allDone:
-        logger.Info("all sessions quit in drain wait time")
-    case <-time.After(drainWait):
-        logger.Info("timeout waiting all sessions quit")
+    case <-ctx.Done():
+        s.KillAllConnections()
+    case <-done:
+        return
+    }
+}
+
+// GracefulDown waits for all clients to close.
+func (s *Server) GracefulDown(ctx context.Context, done chan struct{}) {
+    logutil.Logger(ctx).Info("[server] graceful shutdown.")
+    metrics.ServerEventCounter.WithLabelValues(metrics.EventGracefulDown).Inc()
+
+    count := s.ConnectionCount()
+    for i := 0; count > 0; i++ {
+        s.kickIdleConnection()
+
+        count = s.ConnectionCount()
+        if count == 0 {
+            break
+        }
+        // Print progress information every 30s.
+        if i%30 == 0 {
+            logutil.Logger(ctx).Info("graceful shutdown...", zap.Int("conn count", count))
+        }
+        ticker := time.After(time.Second)
+        select {
+        case <-ctx.Done():
+            return
+        case <-ticker:
+        }
     }
+    close(done)
+}
 
-    s.KillAllConnections()
+func (s *Server) kickIdleConnection() {
+    var conns []*clientConn
+    s.rwlock.RLock()
+    for _, cc := range s.clients {
+        if cc.ShutdownOrNotify() {
+            // Shut-down conns will be closed by us, and notified conns will exit by themselves.
+            conns = append(conns, cc)
+        }
+    }
+    s.rwlock.RUnlock()
 
-    select {
-    case <-allDone:
-    case <-time.After(cancelWait):
-        logger.Warn("some sessions do not quit in cancel wait time")
+    for _, cc := range conns {
+        err := cc.Close()
+        if err != nil {
+            logutil.BgLogger().Error("close connection", zap.Error(err))
+        }
     }
 }
diff --git a/tidb-server/main.go b/tidb-server/main.go
index e5ead347a1841..80bcf40392b6c 100644
--- a/tidb-server/main.go
+++ b/tidb-server/main.go
@@ -836,23 +836,17 @@ func closeDomainAndStorage(storage kv.Storage, dom *domain.Domain) {
     terror.Log(errors.Trace(err))
 }
 
-var gracefulCloseConnectionsTimeout = 15 * time.Second
-
 func cleanup(svr *server.Server, storage kv.Storage, dom *domain.Domain, graceful bool) {
     dom.StopAutoAnalyze()
-
-    var drainClientWait time.Duration
     if graceful {
-        drainClientWait = 1<<63 - 1
+        done := make(chan struct{})
+        svr.GracefulDown(context.Background(), done)
+        // Kill sys processes such as auto analyze. Otherwise, tidb-server cannot exit until auto analyze is finished.
+        // See https://github.com/pingcap/tidb/issues/40038 for details.
+        svr.KillSysProcesses()
     } else {
-        drainClientWait = gracefulCloseConnectionsTimeout
+        svr.TryGracefulDown()
    }
-    cancelClientWait := time.Second * 1
-    svr.DrainClients(drainClientWait, cancelClientWait)
-
-    // Kill sys processes such as auto analyze. Otherwise, tidb-server cannot exit until auto analyze is finished.
-    // See https://github.com/pingcap/tidb/issues/40038 for details.
-    svr.KillSysProcesses()
     plugin.Shutdown(context.Background())
     closeDomainAndStorage(storage, dom)
     disk.CleanUp()
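
For reference, the shutdown handshake this patch restores works as follows: kickIdleConnection calls ShutdownOrNotify on every registered connection; a connection inside a transaction is left alone, an idle connection in the reading state is flipped to shutdown by CAS and closed by the server, and a busy connection in the dispatching state is only marked WaitShutdown so that its own Run loop exits at the next iteration. The minimal, standalone Go sketch below illustrates that CAS handshake; the conn type, the constant values, and shutdownOrNotify here are illustrative assumptions, not the patch's actual server.clientConn code.

// Standalone sketch (not TiDB code): models the connection-status handshake that
// ShutdownOrNotify and the clientConn.Run loop in this diff rely on. Status names
// mirror the diff's connStatus* constants; everything else is illustrative.
package main

import (
    "fmt"
    "sync/atomic"
)

const (
    connStatusDispatching int32 = iota
    connStatusReading
    connStatusShutdown
    connStatusWaitShutdown
)

type conn struct {
    status int32
    inTxn  bool
}

// shutdownOrNotify mimics the decision in the diff: leave transactions alone,
// shut down idle (reading) connections immediately via CAS, and only notify
// busy (dispatching) connections so their dispatch loop can exit on its own.
func (c *conn) shutdownOrNotify() bool {
    if c.inTxn {
        return false
    }
    if atomic.CompareAndSwapInt32(&c.status, connStatusReading, connStatusShutdown) {
        return true // safe for the server to close this connection now
    }
    atomic.StoreInt32(&c.status, connStatusWaitShutdown)
    return false // the connection notices WaitShutdown and exits by itself
}

func main() {
    idle := &conn{status: connStatusReading}
    busy := &conn{status: connStatusDispatching}
    fmt.Println(idle.shutdownOrNotify(), atomic.LoadInt32(&idle.status)) // true 2 (shutdown)
    fmt.Println(busy.shutdownOrNotify(), atomic.LoadInt32(&busy.status)) // false 3 (wait shutdown)
}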