diff --git a/server/http_status.go b/server/http_status.go index b0928c190109f..41a93b0cffe90 100644 --- a/server/http_status.go +++ b/server/http_status.go @@ -71,6 +71,33 @@ func sleepWithCtx(ctx context.Context, d time.Duration) { } } +func (s *Server) listenStatusHTTPServer() error { + s.statusAddr = fmt.Sprintf("%s:%d", s.cfg.Status.StatusHost, s.cfg.Status.StatusPort) + if s.cfg.Status.StatusPort == 0 { + s.statusAddr = fmt.Sprintf("%s:%d", s.cfg.Status.StatusHost, defaultStatusPort) + } + + logutil.BgLogger().Info("for status and metrics report", zap.String("listening on addr", s.statusAddr)) + tlsConfig, err := s.cfg.Security.ToTLSConfig() + if err != nil { + logutil.BgLogger().Error("invalid TLS config", zap.Error(err)) + return errors.Trace(err) + } + tlsConfig = s.setCNChecker(tlsConfig) + + if tlsConfig != nil { + // we need to manage TLS here for cmux to distinguish between HTTP and gRPC. + s.statusListener, err = tls.Listen("tcp", s.statusAddr, tlsConfig) + } else { + s.statusListener, err = net.Listen("tcp", s.statusAddr) + } + if err != nil { + logutil.BgLogger().Info("listen failed", zap.Error(err)) + return errors.Trace(err) + } + return nil +} + func (s *Server) startHTTPServer() { router := mux.NewRouter() @@ -123,13 +150,8 @@ func (s *Server) startHTTPServer() { router.Handle("/mvcc/hex/{hexKey}", mvccTxnHandler{tikvHandlerTool, opMvccGetByHex}) router.Handle("/mvcc/index/{db}/{table}/{index}/{handle}", mvccTxnHandler{tikvHandlerTool, opMvccGetByIdx}) - addr := fmt.Sprintf("%s:%d", s.cfg.Status.StatusHost, s.cfg.Status.StatusPort) - if s.cfg.Status.StatusPort == 0 { - addr = fmt.Sprintf("%s:%d", s.cfg.Status.StatusHost, defaultStatusPort) - } - // HTTP path for web UI. - if host, port, err := net.SplitHostPort(addr); err == nil { + if host, port, err := net.SplitHostPort(s.statusAddr); err == nil { if host == "" { host = "localhost" } @@ -271,40 +293,17 @@ func (s *Server) startHTTPServer() { logutil.BgLogger().Error("write HTTP index page failed", zap.Error(err)) } }) - - logutil.BgLogger().Info("for status and metrics report", zap.String("listening on addr", addr)) - s.setupStatusServerAndRPCServer(addr, serverMux) + s.startStatusServerAndRPCServer(serverMux) } -func (s *Server) setupStatusServerAndRPCServer(addr string, serverMux *http.ServeMux) { - tlsConfig, err := s.cfg.Security.ToTLSConfig() - if err != nil { - logutil.BgLogger().Error("invalid TLS config", zap.Error(err)) - return - } - tlsConfig = s.setCNChecker(tlsConfig) - - var l net.Listener - if tlsConfig != nil { - // we need to manage TLS here for cmux to distinguish between HTTP and gRPC. - l, err = tls.Listen("tcp", addr, tlsConfig) - } else { - l, err = net.Listen("tcp", addr) - } - if err != nil { - logutil.BgLogger().Info("listen failed", zap.Error(err)) - return - } - if tlsConfig != nil { - logutil.BgLogger().Info("HTTP/gRPC status server secure connection is enabled", zap.Bool("CN verification enabled", tlsConfig.VerifyPeerCertificate != nil)) - } - m := cmux.New(l) +func (s *Server) startStatusServerAndRPCServer(serverMux *http.ServeMux) { + m := cmux.New(s.statusListener) // Match connections in order: // First HTTP, and otherwise grpc. httpL := m.Match(cmux.HTTP1Fast()) grpcL := m.Match(cmux.Any()) - s.statusServer = &http.Server{Addr: addr, Handler: CorsHandler{handler: serverMux, cfg: s.cfg}} + s.statusServer = &http.Server{Addr: s.statusAddr, Handler: CorsHandler{handler: serverMux, cfg: s.cfg}} s.grpcServer = NewRPCServer(s.cfg, s.dom, s) go util.WithRecovery(func() { @@ -317,7 +316,7 @@ func (s *Server) setupStatusServerAndRPCServer(addr string, serverMux *http.Serv logutil.BgLogger().Error("http server error", zap.Error(err)) }, nil) - err = m.Serve() + err := m.Serve() if err != nil { logutil.BgLogger().Error("start status/rpc server error", zap.Error(err)) } diff --git a/server/server.go b/server/server.go index 12458775584ac..f4685303864b4 100644 --- a/server/server.go +++ b/server/server.go @@ -116,8 +116,11 @@ type Server struct { clients map[uint32]*clientConn capability uint32 dom *domain.Domain - statusServer *http.Server - grpcServer *grpc.Server + + statusAddr string + statusListener net.Listener + statusServer *http.Server + grpcServer *grpc.Server } // ConnectionCount gets current connection count. @@ -256,6 +259,9 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { s.listener = pplistener } + if s.cfg.Status.ReportStatus && err == nil { + err = s.listenStatusHTTPServer() + } if err != nil { return nil, errors.Trace(err) } diff --git a/server/tidb_test.go b/server/tidb_test.go index 961946e411147..bd83e123273ec 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -22,6 +22,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" + "fmt" "io/ioutil" "math/big" "net/http" @@ -180,6 +181,27 @@ func (ts *tidbTestSuite) TestStatusAPI(c *C) { ts.runTestStatusAPI(c) } +func (ts *tidbTestSuite) TestStatusPort(c *C) { + var err error + ts.store, err = mockstore.NewMockTikvStore() + session.DisableStats4Test() + c.Assert(err, IsNil) + ts.domain, err = session.BootstrapSession(ts.store) + c.Assert(err, IsNil) + ts.tidbdrv = NewTiDBDriver(ts.store) + cfg := config.NewConfig() + cfg.Port = genPort() + cfg.Status.ReportStatus = true + cfg.Status.StatusPort = ts.statusPort + cfg.Performance.TCPKeepAlive = true + + server, err := NewServer(cfg, ts.tidbdrv) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, + fmt.Sprintf("listen tcp 0.0.0.0:%d: bind: address already in use", ts.statusPort)) + c.Assert(server, IsNil) +} + func (ts *tidbTestSuite) TestStatusAPIWithTLS(c *C) { caCert, caKey, err := generateCert(0, "TiDB CA 2", nil, nil, "/tmp/ca-key-2.pem", "/tmp/ca-cert-2.pem") c.Assert(err, IsNil)