diff --git a/lib/config/namespace.go b/lib/config/namespace.go index 0d783832..a6348f9a 100644 --- a/lib/config/namespace.go +++ b/lib/config/namespace.go @@ -22,9 +22,8 @@ type FrontendNamespace struct { } type BackendNamespace struct { - Instances []string `yaml:"instances" json:"instances" toml:"instances"` - SelectorType string `yaml:"selector-type" json:"selector-type" toml:"selector-type"` - Security TLSConfig `yaml:"security" json:"security" toml:"security"` + Instances []string `yaml:"instances" json:"instances" toml:"instances"` + Security TLSConfig `yaml:"security" json:"security" toml:"security"` //HealthCheck HealthCheck `yaml:"health-check" json:"health-check" toml:"health-check"` } diff --git a/lib/config/namespace_test.go b/lib/config/namespace_test.go index 86560682..41aaf7f5 100644 --- a/lib/config/namespace_test.go +++ b/lib/config/namespace_test.go @@ -21,8 +21,7 @@ var testNamespaceConfig = Namespace{ }, }, Backend: BackendNamespace{ - Instances: []string{"127.0.0.1:4000", "127.0.0.1:4001"}, - SelectorType: "random", + Instances: []string{"127.0.0.1:4000", "127.0.0.1:4001"}, Security: TLSConfig{ CA: "t", Cert: "t", diff --git a/lib/config/proxy.go b/lib/config/proxy.go index 4901db9a..0db416e3 100644 --- a/lib/config/proxy.go +++ b/lib/config/proxy.go @@ -46,6 +46,7 @@ type KeepAlive struct { } type ProxyServerOnline struct { + RequireBackendTLS bool `yaml:"require-backend-tls,omitempty" toml:"require-backend-tls,omitempty" json:"require-backend-tls,omitempty"` MaxConnections uint64 `yaml:"max-connections,omitempty" toml:"max-connections,omitempty" json:"max-connections,omitempty"` ConnBufferSize int `yaml:"conn-buffer-size,omitempty" toml:"conn-buffer-size,omitempty" json:"conn-buffer-size,omitempty"` FrontendKeepalive KeepAlive `yaml:"frontend-keepalive" toml:"frontend-keepalive" json:"frontend-keepalive"` @@ -62,17 +63,12 @@ type ProxyServerOnline struct { type ProxyServer struct { Addr string `yaml:"addr,omitempty" toml:"addr,omitempty" json:"addr,omitempty"` PDAddrs string `yaml:"pd-addrs,omitempty" toml:"pd-addrs,omitempty" json:"pd-addrs,omitempty"` - ServerVersion string `yaml:"server-version,omitempty" toml:"server-version,omitempty" json:"server-version,omitempty"` - RequireBackendTLS bool `yaml:"require-backend-tls,omitempty" toml:"require-backend-tls,omitempty" json:"require-backend-tls,omitempty"` ProxyServerOnline `yaml:",inline" toml:",inline" json:",inline"` } type API struct { - Addr string `yaml:"addr,omitempty" toml:"addr,omitempty" json:"addr,omitempty"` - User string `yaml:"user,omitempty" toml:"user,omitempty" json:"user,omitempty"` - Password string `yaml:"password,omitempty" toml:"password,omitempty" json:"password,omitempty"` - EnableBasicAuth bool `yaml:"enable-basic-auth,omitempty" toml:"enable-basic-auth,omitempty" json:"enable-basic-auth,omitempty"` - ProxyProtocol string `yaml:"proxy-protocol,omitempty" toml:"proxy-protocol,omitempty" json:"proxy-protocol,omitempty"` + Addr string `yaml:"addr,omitempty" toml:"addr,omitempty" json:"addr,omitempty"` + ProxyProtocol string `yaml:"proxy-protocol,omitempty" toml:"proxy-protocol,omitempty" json:"proxy-protocol,omitempty"` } type Advance struct { diff --git a/lib/config/proxy_test.go b/lib/config/proxy_test.go index 52dc2734..b62ab628 100644 --- a/lib/config/proxy_test.go +++ b/lib/config/proxy_test.go @@ -18,10 +18,10 @@ var testProxyConfig = Config{ IgnoreWrongNamespace: true, }, Proxy: ProxyServer{ - Addr: "0.0.0.0:4000", - PDAddrs: "127.0.0.1:4089", - RequireBackendTLS: true, + Addr: "0.0.0.0:4000", + PDAddrs: "127.0.0.1:4089", ProxyServerOnline: ProxyServerOnline{ + RequireBackendTLS: true, MaxConnections: 1, FrontendKeepalive: KeepAlive{Enabled: true}, ProxyProtocol: "v2", @@ -30,10 +30,7 @@ var testProxyConfig = Config{ }, }, API: API{ - Addr: "0.0.0.0:3080", - EnableBasicAuth: false, - User: "user", - Password: "pwd", + Addr: "0.0.0.0:3080", }, Metrics: Metrics{ MetricsAddr: "127.0.0.1:9021", diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index 6cddc6d8..fc283196 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -15,6 +15,7 @@ import ( "sync" "time" + "github.com/pingcap/tiproxy/lib/config" "github.com/pingcap/tiproxy/lib/util/systimemon" "github.com/pingcap/tiproxy/lib/util/waitgroup" "github.com/prometheus/client_golang/prometheus" @@ -52,12 +53,12 @@ func NewMetricsManager() *MetricsManager { var registerOnce = &sync.Once{} // Init registers metrics and pushes metrics to prometheus. -func (mm *MetricsManager) Init(ctx context.Context, logger *zap.Logger, metricsAddr string, metricsInterval uint, proxyAddr string) { +func (mm *MetricsManager) Init(ctx context.Context, logger *zap.Logger, proxyAddr string, cfg config.Metrics, cfgch <-chan *config.Config) { mm.logger = logger registerOnce.Do(registerProxyMetrics) ctx, mm.cancel = context.WithCancel(ctx) mm.setupMonitor(ctx) - mm.pushMetric(ctx, metricsAddr, time.Duration(metricsInterval)*time.Second, proxyAddr) + mm.pushMetric(ctx, proxyAddr, cfg, cfgch) } // Close stops all goroutines. @@ -89,17 +90,64 @@ func (mm *MetricsManager) setupMonitor(ctx context.Context) { } // pushMetric pushes metrics in background. -func (mm *MetricsManager) pushMetric(ctx context.Context, addr string, interval time.Duration, proxyAddr string) { - if interval == time.Duration(0) || len(addr) == 0 { - mm.logger.Info("disable Prometheus push client") - return - } - mm.logger.Info("start prometheus push client", zap.String("server addr", addr), zap.String("interval", interval.String())) +func (mm *MetricsManager) pushMetric(ctx context.Context, proxyAddr string, cfg config.Metrics, cfgch <-chan *config.Config) { mm.wg.Run(func() { - prometheusPushClient(ctx, mm.logger, addr, interval, proxyAddr) + proxyInstance := instanceName(proxyAddr) + addr := cfg.MetricsAddr + interval := time.Duration(cfg.MetricsInterval) * time.Second + pusher := mm.buildPusher(addr, interval, proxyInstance) + + for ctx.Err() == nil { + select { + case newCfg := <-cfgch: + if newCfg == nil { + return + } + interval = time.Duration(newCfg.Metrics.MetricsInterval) * time.Second + if addr != newCfg.Metrics.MetricsAddr { + addr = newCfg.Metrics.MetricsAddr + pusher = mm.buildPusher(addr, interval, proxyInstance) + } + default: + } + + // Wait until the config is legal. + if interval == 0 || pusher == nil { + select { + case <-time.After(time.Second): + continue + case <-ctx.Done(): + return + } + } + + if err := pusher.Push(); err != nil { + mm.logger.Error("could not push metrics to prometheus pushgateway", zap.Error(err)) + } + select { + case <-time.After(interval): + case <-ctx.Done(): + return + } + } }) } +func (mm *MetricsManager) buildPusher(addr string, interval time.Duration, proxyInstance string) *push.Pusher { + var pusher *push.Pusher + if len(addr) > 0 { + // Create a new pusher when the address changes. + mm.logger.Info("start prometheus push client", zap.String("server addr", addr), zap.Stringer("interval", interval)) + pusher = push.New(addr, "tiproxy") + pusher = pusher.Gatherer(prometheus.DefaultGatherer) + pusher = pusher.Grouping("instance", proxyInstance) + } else { + mm.logger.Info("disable prometheus push client") + pusher = nil + } + return pusher +} + // registerProxyMetrics registers metrics. func registerProxyMetrics() { prometheus.DefaultRegisterer.Unregister(collectors.NewGoCollector()) @@ -122,25 +170,6 @@ func registerProxyMetrics() { prometheus.MustRegister(MigrateDurationHistogram) } -// prometheusPushClient pushes metrics to Prometheus Pushgateway. -func prometheusPushClient(ctx context.Context, logger *zap.Logger, addr string, interval time.Duration, proxyAddr string) { - job := "tiproxy" - pusher := push.New(addr, job) - pusher = pusher.Gatherer(prometheus.DefaultGatherer) - pusher = pusher.Grouping("instance", instanceName(proxyAddr)) - for ctx.Err() == nil { - err := pusher.Push() - if err != nil { - logger.Error("could not push metrics to prometheus pushgateway", zap.String("err", err.Error())) - } - select { - case <-time.After(interval): - case <-ctx.Done(): - return - } - } -} - func instanceName(proxyAddr string) string { hostname, err := os.Hostname() if err != nil { diff --git a/pkg/metrics/metrics_test.go b/pkg/metrics/metrics_test.go index daa7a836..6e381caf 100644 --- a/pkg/metrics/metrics_test.go +++ b/pkg/metrics/metrics_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/pingcap/tiproxy/lib/config" "github.com/pingcap/tiproxy/lib/util/logger" "github.com/stretchr/testify/require" ) @@ -21,64 +22,80 @@ import ( func TestPushMetrics(t *testing.T) { proxyAddr := "0.0.0.0:6000" labelName := fmt.Sprintf("%s_%s_maxprocs", ModuleProxy, LabelServer) - hostname, err := os.Hostname() - require.NoError(t, err) - expectedPath := fmt.Sprintf("/metrics/job/tiproxy/instance/%s_6000", hostname) - bodyCh := make(chan string) - pgwOK := httptest.NewServer( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - body, err := io.ReadAll(r.Body) - require.NoError(t, err) - bodyCh <- string(body) - require.Equal(t, expectedPath, r.URL.EscapedPath()) - w.Header().Set("Content-Type", `text/plain; charset=utf-8`) - w.WriteHeader(http.StatusOK) - }), - ) - defer pgwOK.Close() + bodyCh1, bodyCh2 := make(chan string), make(chan string) + pgwOK1, pgwOK2 := setupServer(t, bodyCh1), setupServer(t, bodyCh2) log, _ := logger.CreateLoggerForTest(t) tests := []struct { metricsAddr string metricsInterval uint - pushed bool + pushedCh chan string }{ { - metricsAddr: pgwOK.URL, + metricsAddr: pgwOK1.URL, metricsInterval: 1, - pushed: true, + pushedCh: bodyCh1, }, { - metricsAddr: "", + metricsAddr: pgwOK1.URL, + metricsInterval: 0, + pushedCh: nil, + }, + { + metricsAddr: pgwOK2.URL, metricsInterval: 1, - pushed: false, + pushedCh: bodyCh2, }, { - metricsAddr: pgwOK.URL, - metricsInterval: 0, - pushed: false, + metricsAddr: "", + metricsInterval: 1, + pushedCh: nil, }, } + mm := NewMetricsManager() + cfgCh := make(chan *config.Config, 1) + mm.Init(context.Background(), log, proxyAddr, config.Metrics{}, cfgCh) for _, tt := range tests { - for len(bodyCh) > 0 { - <-bodyCh + cfgCh <- &config.Config{ + Metrics: config.Metrics{ + MetricsAddr: tt.metricsAddr, + MetricsInterval: tt.metricsInterval, + }, } - mm := NewMetricsManager() - mm.Init(context.Background(), log, tt.metricsAddr, tt.metricsInterval, proxyAddr) - if tt.pushed { + if tt.pushedCh != nil { select { - case body := <-bodyCh: + case body := <-tt.pushedCh: require.Contains(t, body, labelName) case <-time.After(2 * time.Second): t.Fatal("not pushed") } } else { select { - case <-bodyCh: - t.Fatal("pushed") + case <-bodyCh1: + t.Fatal("pushed 1") + case <-bodyCh2: + t.Fatal("pushed 2") case <-time.After(2 * time.Second): } } - mm.Close() } + mm.Close() +} + +func setupServer(t *testing.T, bodyCh chan string) *httptest.Server { + hostname, err := os.Hostname() + require.NoError(t, err) + expectedPath := fmt.Sprintf("/metrics/job/tiproxy/instance/%s_6000", hostname) + server := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + require.NoError(t, err) + bodyCh <- string(body) + require.Equal(t, expectedPath, r.URL.EscapedPath()) + w.Header().Set("Content-Type", `text/plain; charset=utf-8`) + w.WriteHeader(http.StatusOK) + }), + ) + t.Cleanup(server.Close) + return server } diff --git a/pkg/proxy/backend/handshake_handler.go b/pkg/proxy/backend/handshake_handler.go index 50293d0e..cb2665ac 100644 --- a/pkg/proxy/backend/handshake_handler.go +++ b/pkg/proxy/backend/handshake_handler.go @@ -82,14 +82,12 @@ type HandshakeHandler interface { } type DefaultHandshakeHandler struct { - nsManager *namespace.NamespaceManager - serverVersion string + nsManager *namespace.NamespaceManager } -func NewDefaultHandshakeHandler(nsManager *namespace.NamespaceManager, serverVersion string) *DefaultHandshakeHandler { +func NewDefaultHandshakeHandler(nsManager *namespace.NamespaceManager) *DefaultHandshakeHandler { return &DefaultHandshakeHandler{ - nsManager: nsManager, - serverVersion: serverVersion, + nsManager: nsManager, } } @@ -128,9 +126,6 @@ func (handler *DefaultHandshakeHandler) GetCapability() pnet.Capability { } func (handler *DefaultHandshakeHandler) GetServerVersion() string { - if len(handler.serverVersion) > 0 { - return handler.serverVersion - } // TiProxy sends the server version before getting the router, so we don't know which router to get. // Just get the default one. if ns, ok := handler.nsManager.GetNamespace("default"); ok { diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index 3ba25b44..cd9fbf43 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -31,6 +31,7 @@ type serverState struct { connID uint64 maxConnections uint64 connBufferSize int + requireBackendTLS bool tcpKeepAlive bool proxyProtocol bool gracefulWait int @@ -38,14 +39,13 @@ type serverState struct { } type SQLServer struct { - listeners []net.Listener - addrs []string - logger *zap.Logger - certMgr *cert.CertManager - hsHandler backend.HandshakeHandler - requireBackendTLS bool - wg waitgroup.WaitGroup - cancelFunc context.CancelFunc + listeners []net.Listener + addrs []string + logger *zap.Logger + certMgr *cert.CertManager + hsHandler backend.HandshakeHandler + wg waitgroup.WaitGroup + cancelFunc context.CancelFunc mu serverState } @@ -53,12 +53,10 @@ type SQLServer struct { // NewSQLServer creates a new SQLServer. func NewSQLServer(logger *zap.Logger, cfg config.ProxyServer, certMgr *cert.CertManager, hsHandler backend.HandshakeHandler) (*SQLServer, error) { var err error - s := &SQLServer{ - logger: logger, - certMgr: certMgr, - hsHandler: hsHandler, - requireBackendTLS: cfg.RequireBackendTLS, + logger: logger, + certMgr: certMgr, + hsHandler: hsHandler, mu: serverState{ connID: 0, clients: make(map[uint64]*client.ClientConnection), @@ -83,6 +81,7 @@ func (s *SQLServer) reset(cfg *config.ProxyServerOnline) { s.mu.Lock() s.mu.tcpKeepAlive = cfg.FrontendKeepalive.Enabled s.mu.maxConnections = cfg.MaxConnections + s.mu.requireBackendTLS = cfg.RequireBackendTLS s.mu.proxyProtocol = cfg.ProxyProtocol != "" s.mu.gracefulWait = cfg.GracefulWaitBeforeShutdown s.mu.healthyKeepAlive = cfg.BackendHealthyKeepalive @@ -162,13 +161,13 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) { clientConn := client.NewClientConnection(logger.Named("conn"), conn, s.certMgr.ServerTLS(), s.certMgr.SQLTLS(), s.hsHandler, connID, addr, &backend.BCConfig{ ProxyProtocol: s.mu.proxyProtocol, - RequireBackendTLS: s.requireBackendTLS, + RequireBackendTLS: s.mu.requireBackendTLS, HealthyKeepAlive: s.mu.healthyKeepAlive, UnhealthyKeepAlive: s.mu.unhealthyKeepAlive, ConnBufferSize: s.mu.connBufferSize, }) s.mu.clients[connID] = clientConn - logger.Info("new connection", zap.Bool("proxy-protocol", s.mu.proxyProtocol)) + logger.Info("new connection", zap.Bool("proxy-protocol", s.mu.proxyProtocol), zap.Bool("require_backend_tls", s.mu.requireBackendTLS)) s.mu.Unlock() metrics.ConnGauge.Inc() diff --git a/pkg/proxy/proxy_test.go b/pkg/proxy/proxy_test.go index 068d4535..9263fdd3 100644 --- a/pkg/proxy/proxy_test.go +++ b/pkg/proxy/proxy_test.go @@ -27,7 +27,7 @@ import ( func TestGracefulShutdown(t *testing.T) { // Graceful shutdown finishes immediately if there's no connection. lg, _ := logger.CreateLoggerForTest(t) - hsHandler := backend.NewDefaultHandshakeHandler(nil, "") + hsHandler := backend.NewDefaultHandshakeHandler(nil) server, err := NewSQLServer(lg, config.ProxyServer{ ProxyServerOnline: config.ProxyServerOnline{ GracefulWaitBeforeShutdown: 10, @@ -130,6 +130,37 @@ func TestMultiAddr(t *testing.T) { certManager.Close() } +func TestWatchCfg(t *testing.T) { + lg, _ := logger.CreateLoggerForTest(t) + hsHandler := backend.NewDefaultHandshakeHandler(nil) + cfgch := make(chan *config.Config) + server, err := NewSQLServer(lg, config.ProxyServer{}, nil, hsHandler) + require.NoError(t, err) + server.Run(context.Background(), cfgch) + cfg := &config.Config{ + Proxy: config.ProxyServer{ + ProxyServerOnline: config.ProxyServerOnline{ + RequireBackendTLS: true, + MaxConnections: 100, + ConnBufferSize: 1024 * 1024, + ProxyProtocol: "v2", + GracefulWaitBeforeShutdown: 100, + }, + }, + } + cfgch <- cfg + require.Eventually(t, func() bool { + server.mu.RLock() + defer server.mu.RUnlock() + return server.mu.requireBackendTLS == cfg.Proxy.RequireBackendTLS && + server.mu.maxConnections == cfg.Proxy.MaxConnections && + server.mu.connBufferSize == cfg.Proxy.ConnBufferSize && + server.mu.proxyProtocol == (cfg.Proxy.ProxyProtocol != "") && + server.mu.gracefulWait == cfg.Proxy.GracefulWaitBeforeShutdown + }, 3*time.Second, 10*time.Millisecond) + require.NoError(t, server.Close()) +} + func TestRecoverPanic(t *testing.T) { lg, text := logger.CreateLoggerForTest(t) certManager := cert.NewCertManager() diff --git a/pkg/server/api/namespace_test.go b/pkg/server/api/namespace_test.go index 2f537909..06b984cc 100644 --- a/pkg/server/api/namespace_test.go +++ b/pkg/server/api/namespace_test.go @@ -35,7 +35,7 @@ func TestNamespace(t *testing.T) { doHTTP(t, http.MethodGet, "/api/admin/namespace/dge", nil, func(t *testing.T, r *http.Response) { all, err := io.ReadAll(r.Body) require.NoError(t, err) - require.Equal(t, `{"namespace":"dge","frontend":{"user":"","security":{}},"backend":{"instances":null,"selector-type":"","security":{}}}`, string(all)) + require.Equal(t, `{"namespace":"dge","frontend":{"user":"","security":{}},"backend":{"instances":null,"security":{}}}`, string(all)) require.Equal(t, http.StatusOK, r.StatusCode) }) diff --git a/pkg/server/api/server.go b/pkg/server/api/server.go index a3e991b8..c33384e6 100644 --- a/pkg/server/api/server.go +++ b/pkg/server/api/server.go @@ -203,9 +203,6 @@ func (h *Server) grpcServer(ctx *gin.Context) { func (h *Server) registerAPI(g *gin.RouterGroup, cfg config.API, nsmgr *mgrns.NamespaceManager, cfgmgr *mgrcfg.ConfigManager) { { adminGroup := g.Group("admin") - if cfg.EnableBasicAuth { - adminGroup.Use(gin.BasicAuth(gin.Accounts{cfg.User: cfg.Password})) - } h.registerNamespace(adminGroup.Group("namespace")) h.registerConfig(adminGroup.Group("config")) } diff --git a/pkg/server/server.go b/pkg/server/server.go index ede2c2a1..209524bd 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -71,7 +71,7 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, err error) cfg := srv.ConfigManager.GetConfig() // setup metrics - srv.MetricsManager.Init(ctx, lg.Named("metrics"), cfg.Metrics.MetricsAddr, cfg.Metrics.MetricsInterval, cfg.Proxy.Addr) + srv.MetricsManager.Init(ctx, lg.Named("metrics"), cfg.Proxy.Addr, cfg.Metrics, srv.ConfigManager.WatchConfig()) metrics.ServerEventCounter.WithLabelValues(metrics.EventStart).Inc() // setup certs @@ -109,8 +109,7 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, err error) nsc := &config.Namespace{ Namespace: "default", Backend: config.BackendNamespace{ - Instances: []string{}, - SelectorType: "random", + Instances: []string{}, }, } if err = srv.ConfigManager.SetNamespace(ctx, nsc.Namespace, nsc); err != nil { @@ -132,7 +131,7 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, err error) if handler != nil { hsHandler = handler } else { - hsHandler = backend.NewDefaultHandshakeHandler(srv.NamespaceManager, cfg.Proxy.ServerVersion) + hsHandler = backend.NewDefaultHandshakeHandler(srv.NamespaceManager) } srv.Proxy, err = proxy.NewSQLServer(lg.Named("proxy"), cfg.Proxy, srv.CertManager, hsHandler) if err != nil {