From d7e2e0fcec8b44bcc82ab58e2289a806e2073be1 Mon Sep 17 00:00:00 2001 From: Caleb Lloyd Date: Wed, 24 Aug 2022 17:00:29 -0400 Subject: [PATCH] implement prometheus.Gatherer interface Signed-off-by: Caleb Lloyd --- surveyor/surveyor.go | 146 ++++++++++++++++++++++---------------- surveyor/surveyor_test.go | 55 ++++++++++---- 2 files changed, 127 insertions(+), 74 deletions(-) diff --git a/surveyor/surveyor.go b/surveyor/surveyor.go index 0f56567..b7de425 100644 --- a/surveyor/surveyor.go +++ b/surveyor/surveyor.go @@ -15,9 +15,11 @@ package surveyor import ( + "context" "crypto/tls" "crypto/x509" "encoding/base64" + "errors" "fmt" "net" "net/http" @@ -31,6 +33,7 @@ import ( "github.com/nats-io/nats.go" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" + dto "github.com/prometheus/client_model/go" "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt" ) @@ -73,7 +76,8 @@ type Options struct { JetStreamConfigDir string Accounts bool Logger *logrus.Logger - ConstLabels prometheus.Labels + ConstLabels prometheus.Labels // not exposed by CLI + DisableHTTPServer bool // not exposed by CLI } // GetDefaultOptions returns the default set of options @@ -100,8 +104,8 @@ type Surveyor struct { sync.Mutex opts Options logger *logrus.Logger - nc *nats.Conn - http net.Listener + listener net.Listener + httpServer *http.Server promRegistry *prometheus.Registry reconnectCtr *prometheus.CounterVec statzC *StatzCollector @@ -167,6 +171,10 @@ func connect(opts *Options, reconnectCtr *prometheus.CounterVec) (*nats.Conn, er // NewSurveyor creates a surveyor func NewSurveyor(opts *Options) (*Surveyor, error) { + if opts.URLs == "" { + return nil, fmt.Errorf("surveyor URLs field is required") + } + promRegistry := prometheus.NewRegistry() reconnectCtr := prometheus.NewCounterVec(prometheus.CounterOpts{ Name: prometheus.BuildFQName("nats", "survey", "nats_reconnects"), @@ -174,19 +182,12 @@ func NewSurveyor(opts *Options) 
(*Surveyor, error) { ConstLabels: opts.ConstLabels, }, []string{"name"}) promRegistry.MustRegister(reconnectCtr) - nc, err := connect(opts, reconnectCtr) - if err != nil { - return nil, err - } return &Surveyor{ - nc: nc, opts: *opts, logger: opts.Logger, promRegistry: promRegistry, reconnectCtr: reconnectCtr, - observations: []*ServiceObsListener{}, observationMetrics: NewServiceObservationMetrics(promRegistry, opts.ConstLabels), - jsAPIAudits: []*JSAdvisoryListener{}, jsAPIMetrics: NewJetStreamAdvisoryMetrics(promRegistry, opts.ConstLabels), }, nil } @@ -196,26 +197,18 @@ func (s *Surveyor) createStatszCollector() error { return nil } + nc, err := connect(&s.opts, s.reconnectCtr) + if err != nil { + return err + } + if !s.opts.Accounts { s.logger.Debugln("Skipping per-account exports") } - s.Lock() - s.statzC = NewStatzCollector(s.nc, s.logger, s.opts.ExpectedServers, s.opts.PollTimeout, s.opts.Accounts, s.opts.ConstLabels) - s.Unlock() - - err := s.promRegistry.Register(s.statzC) - for i := 0; i < 50 && err != nil; i++ { - if _, ok := err.(prometheus.AlreadyRegisteredError); ok { - // ignore - return nil - } - - s.logger.Warnf("Error registering statsz collector, will retry after 500ms: %v", err) - time.Sleep(500 * time.Millisecond) - err = s.promRegistry.Register(s.statzC) - } - return err + s.statzC = NewStatzCollector(nc, s.logger, s.opts.ExpectedServers, s.opts.PollTimeout, s.opts.Accounts, s.opts.ConstLabels) + s.promRegistry.MustRegister(s.statzC) + return nil } // generates the TLS config for https @@ -300,9 +293,7 @@ func (s *Surveyor) httpAuthMiddleware(next http.Handler) http.Handler { func (s *Surveyor) httpConcurrentPollBlockMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - s.Lock() sz := s.statzC - s.Unlock() if sz == nil { next.ServeHTTP(rw, r) @@ -336,6 +327,7 @@ func (s *Surveyor) startHTTP() error { var err error var proto string var config *tls.Config + var listener 
net.Listener hp = net.JoinHostPort(s.opts.ListenAddress, strconv.Itoa(s.opts.ListenPort)) @@ -343,50 +335,44 @@ func (s *Surveyor) startHTTP() error { // key provided. if s.opts.HTTPCertFile != "" { proto = "https" - // debug - s.logger.Debugln("Certificate file specfied; using https.") + s.logger.Debugln("Certificate file specified; using https.") config, err = s.generateHTTPTLSConfig() if err != nil { return err } - s.http, err = tls.Listen("tcp", hp, config) + listener, err = tls.Listen("tcp", hp, config) } else { proto = "http" - - // debug - s.logger.Debugln("No certificate file specified; using http.") - s.http, err = net.Listen("tcp", hp) + s.logger.Debugln("No certificate file specified; using http.") + listener, err = net.Listen("tcp", hp) } - s.logger.Infof("Prometheus exporter listening at %s://%s/metrics", proto, hp) - if err != nil { s.logger.Errorf("can't start HTTP listener: %v", err) return err } + s.listener = listener + s.logger.Infof("Prometheus exporter listening at %s://%s/metrics", proto, hp) + mux := http.NewServeMux() mux.Handle("/metrics", s.getScrapeHandler()) mux.HandleFunc("/healthz", func(resp http.ResponseWriter, req *http.Request) { resp.Write([]byte("ok")) }) - srv := &http.Server{ + httpServer := &http.Server{ Addr: hp, Handler: mux, MaxHeaderBytes: 1 << 20, TLSConfig: config, } + s.httpServer = httpServer - sHTTP := s.http go func() { - for i := 0; i < 10; i++ { - var err error - if err = srv.Serve(sHTTP); err != nil { - // In a test environment, this can fail because the server is already running. 
- // debugf - s.logger.Errorf("Unable to start HTTP server (may already be running): %v", err) - } + err := httpServer.Serve(listener) + if err != nil && !errors.Is(err, http.ErrServerClosed) { + s.logger.Errorf("Unable to start HTTP server (may already be running): %v", err) } }() @@ -394,6 +380,7 @@ func (s *Surveyor) startHTTP() error { } func (s *Surveyor) startJetStreamAdvisories() error { + s.jsAPIAudits = []*JSAdvisoryListener{} s.jsAPIMetrics.jsAdvisoriesGauge.Set(0) dir := s.opts.JetStreamConfigDir @@ -439,6 +426,7 @@ func (s *Surveyor) startJetStreamAdvisories() error { } func (s *Surveyor) startObservations() error { + s.observations = []*ServiceObsListener{} s.observationMetrics.observationsGauge.Set(0) dir := s.opts.ObservationConfigDir @@ -493,20 +481,31 @@ func (s *Surveyor) startObservations() error { // Start starts the surveyor func (s *Surveyor) Start() error { - if err := s.startHTTP(); err != nil { - return err + s.Lock() + defer s.Unlock() + + if s.statzC == nil { + if err := s.createStatszCollector(); err != nil { + return err + } } - if err := s.createStatszCollector(); err != nil { - return err + if s.observations == nil { + if err := s.startObservations(); err != nil { + return err + } } - if err := s.startObservations(); err != nil { - return err + if s.jsAPIAudits == nil { + if err := s.startJetStreamAdvisories(); err != nil { + return err + } } - if err := s.startJetStreamAdvisories(); err != nil { - return err + if !s.opts.DisableHTTPServer && s.listener == nil && s.httpServer == nil { + if err := s.startHTTP(); err != nil { + return err + } } return nil @@ -515,13 +514,40 @@ func (s *Surveyor) Start() error { // Stop stops the surveyor func (s *Surveyor) Stop() { s.Lock() + defer s.Unlock() + + if s.httpServer != nil { + _ = s.httpServer.Shutdown(context.Background()) + s.httpServer = nil + } + + if s.listener != nil { + _ = s.listener.Close() + s.listener = nil + } + + if s.statzC != nil { + s.promRegistry.Unregister(s.statzC) + 
s.statzC.nc.Close() + s.statzC = nil + } + + if s.observations != nil { + for _, o := range s.observations { + o.nc.Close() + } + s.observations = nil + } - for _, o := range s.observations { - o.nc.Close() + if s.jsAPIAudits != nil { + for _, j := range s.jsAPIAudits { + j.nc.Close() + } + s.jsAPIAudits = nil } +} - s.promRegistry.Unregister(s.statzC) - s.http.Close() - s.nc.Drain() - s.Unlock() +// Gather implements the prometheus.Gatherer interface +func (s *Surveyor) Gather() ([]*dto.MetricFamily, error) { + return s.promRegistry.Gather() } diff --git a/surveyor/surveyor_test.go b/surveyor/surveyor_test.go index 7f7e402..4567047 100644 --- a/surveyor/surveyor_test.go +++ b/surveyor/surveyor_test.go @@ -162,6 +162,24 @@ func TestSurveyor_Basic(t *testing.T) { } } +func TestSurveyor_StartTwice(t *testing.T) { + sc := st.NewSuperCluster(t) + defer sc.Shutdown() + + s, err := NewSurveyor(getTestOptions()) + if err != nil { + t.Fatalf("couldn't create surveyor: %v", err) + } + if err = s.Start(); err != nil { + t.Fatalf("start error: %v", err) + } + s.Stop() + if err = s.Start(); err != nil { + t.Fatalf("second start error: %v", err) + } + s.Stop() +} + func TestSurveyor_Account(t *testing.T) { sc := st.NewSuperCluster(t) defer sc.Shutdown() @@ -262,7 +280,13 @@ func TestSurveyor_ClientTLSFail(t *testing.T) { opts.CertFile = clientCert opts.KeyFile = clientKey - _, err := NewSurveyor(opts) + s, err := NewSurveyor(opts) + if err != nil { + t.Fatalf("couldn't create surveyor: %v", err) + } + err = s.Start() + defer s.Stop() + if err == nil { t.Fatalf("Connected to a server that required TLS") } @@ -357,11 +381,11 @@ func TestSurveyor_UserPass(t *testing.T) { func TestSurveyor_NoServer(t *testing.T) { s, err := NewSurveyor(getTestOptions()) - defer func() { - if s != nil { - s.Stop() - } - }() + if err != nil { + t.Fatalf("couldn't create surveyor: %v", err) + } + err = s.Start() + defer s.Stop() if err == nil { t.Fatalf("didn't get expected error") @@ -394,18 +418,17 
@@ func TestSurveyor_Observations(t *testing.T) { sc := st.NewSuperCluster(t) defer sc.Shutdown() - opt := getTestOptions() - opt.ObservationConfigDir = "testdata/goodobs" + opts := getTestOptions() + opts.ObservationConfigDir = "testdata/goodobs" - s, err := NewSurveyor(opt) + s, err := NewSurveyor(opts) if err != nil { t.Fatalf("couldn't create surveyor: %v", err) } - defer s.Stop() - if err = s.Start(); err != nil { t.Fatalf("start error: %v", err) } + defer s.Stop() if ptu.ToFloat64(s.observationMetrics.observationsGauge) != 1 { t.Fatalf("process error: observations not started") @@ -420,7 +443,6 @@ func TestSurveyor_ConcurrentBlock(t *testing.T) { if err != nil { t.Fatalf("couldn't create surveyor: %v", err) } - if err = s.Start(); err != nil { t.Fatalf("start error: %v", err) } @@ -446,17 +468,22 @@ func TestSurveyor_NATSUserPass(t *testing.T) { opts.NATSUser = "invalid_user" opts.NATSPassword = "password" - _, err := NewSurveyor(opts) + s, err := NewSurveyor(opts) + if err != nil { + t.Fatalf("couldn't create surveyor: %v", err) + } + err = s.Start() if err == nil { t.Fatalf("didn't receive expected error") } if !strings.Contains(err.Error(), "Auth") { t.Fatalf("didn't receive expected error: %v", err) } + s.Stop() opts.NATSUser = "sys" opts.NATSPassword = "password" - s, err := NewSurveyor(opts) + s, err = NewSurveyor(opts) if err != nil { t.Fatalf("couldn't create surveyor: %v", err) }