diff --git a/control/control.go b/control/control.go index 847236429..cdf7a50f2 100644 --- a/control/control.go +++ b/control/control.go @@ -89,6 +89,10 @@ type pluginControl struct { pluginTrust int keyringFiles []string + // used to cleanly shutdown the GRPC server + grpcServer *grpc.Server + closingChan chan bool + wg sync.WaitGroup } type runsPlugins interface { @@ -336,12 +340,21 @@ func (p *pluginControl) Start() error { } opts := []grpc.ServerOption{} - grpcServer := grpc.NewServer(opts...) - rpc.RegisterMetricManagerServer(grpcServer, &ControlGRPCServer{p}) + p.closingChan = make(chan bool, 1) + p.grpcServer = grpc.NewServer(opts...) + rpc.RegisterMetricManagerServer(p.grpcServer, &ControlGRPCServer{p}) + p.wg.Add(1) go func() { - err := grpcServer.Serve(lis) + defer p.wg.Done() + err := p.grpcServer.Serve(lis) if err != nil { - controlLogger.Fatal(err) + select { + case <-p.closingChan: + // If we called Stop() then there will be a value in p.closingChan, so + // we'll get here and we can exit without showing the error. + default: + controlLogger.Fatal(err) + } } }() @@ -349,10 +362,16 @@ func (p *pluginControl) Start() error { } func (p *pluginControl) Stop() { + // set the Started flag to false (since we're stopping the server) p.Started = false - controlLogger.WithFields(log.Fields{ - "_block": "stop", - }).Info("control stopped") + + // and add a boolean to the p.closingChan (used for error handling in the + // goroutine that is listening for connections) + p.closingChan <- true + + // stop GRPC server + p.grpcServer.Stop() + p.wg.Wait() // stop runner err := p.pluginRunner.Stop() @@ -368,6 +387,12 @@ func (p *pluginControl) Stop() { // unload plugins p.pluginManager.teardown() + + // log that we've stopped the control module + controlLogger.WithFields(log.Fields{ + "_block": "stop", + }).Info("control stopped") + } // Load is the public method to load a plugin into diff --git a/mgmt/rest/server.go b/mgmt/rest/server.go index 402efbfc5..35bd8d16d 100644 --- a/mgmt/rest/server.go +++ b/mgmt/rest/server.go @@ -20,6 +20,7 @@ limitations under the License. package rest import ( + "crypto/tls" "encoding/json" "errors" "fmt" @@ -169,7 +170,7 @@ type Server struct { mc managesConfig n *negroni.Negroni r *httprouter.Router - tls *tls + snapTLS *snapTLS auth bool authpwd string addrString string @@ -177,6 +178,9 @@ type Server struct { wg sync.WaitGroup killChan chan struct{} err chan error + // the following instance variables are used to cleanly shutdown the server + serverListener net.Listener + closingChan chan bool } // New creates a REST API server with a given config @@ -192,7 +196,7 @@ func New(cfg *Config) (*Server, error) { } if https { var err error - s.tls, err = newtls(cpath, kpath) + s.snapTLS, err = newtls(cpath, kpath) if err != nil { return nil, err } @@ -324,6 +328,7 @@ func (s *Server) Name() string { } func (s *Server) Start() error { + s.closingChan = make(chan bool, 1) s.addRoutes() s.run(s.addrString) restLogger.WithFields(log.Fields{ @@ -333,8 +338,19 @@ func (s *Server) Start() error { } func (s *Server) Stop() { + // add a boolean to the s.closingChan (used for error handling in the + // goroutine that is listening for connections) + s.closingChan <- true + // then close the server close(s.killChan) + // close the server listener + s.serverListener.Close() + // wait for the server goroutines to complete (serve and watch) s.wg.Wait() + // finally log the result + restLogger.WithFields(log.Fields{ + "_block": "stop", + }).Info("REST stopped") } func (s *Server) Err() <-chan error { @@ -347,31 +363,59 @@ func (s *Server) Port() int { func (s *Server) run(addrString string) { restLogger.Info("Starting REST API on ", addrString) - if s.tls != nil { - go s.serveTLS(addrString) + if s.snapTLS != nil { + cer, err := tls.LoadX509KeyPair(s.snapTLS.cert, s.snapTLS.key) + if err != nil { + s.err <- err + return + } + config := &tls.Config{Certificates: []tls.Certificate{cer}} + ln, err := tls.Listen("tcp", addrString, config) + if err != nil { + s.err <- err + } + s.serverListener = ln + s.wg.Add(1) + go s.serveTLS(ln) } else { ln, err := net.Listen("tcp", addrString) if err != nil { s.err <- err } + s.serverListener = ln s.addr = ln.Addr() + s.wg.Add(1) go s.serve(ln) } } -func (s *Server) serveTLS(addrString string) { - err := http.ListenAndServeTLS(addrString, s.tls.cert, s.tls.key, s.n) +func (s *Server) serveTLS(ln net.Listener) { + defer s.wg.Done() + err := http.Serve(ln, s.n) if err != nil { - restLogger.Error(err) - s.err <- err + select { + case <-s.closingChan: + // If we called Stop() then there will be a value in s.closingChan, so + // we'll get here and we can exit without showing the error. + default: + restLogger.Error(err) + s.err <- err + } } } func (s *Server) serve(ln net.Listener) { + defer s.wg.Done() err := http.Serve(tcpKeepAliveListener{ln.(*net.TCPListener)}, s.n) if err != nil { - restLogger.Error(err) - s.err <- err + select { + case <-s.closingChan: + // If we called Stop() then there will be a value in s.closingChan, so + // we'll get here and we can exit without showing the error. + default: + restLogger.Error(err) + s.err <- err + } } } diff --git a/mgmt/rest/tls.go b/mgmt/rest/snapTLS.go similarity index 95% rename from mgmt/rest/tls.go rename to mgmt/rest/snapTLS.go index 65d6986b3..4bcf3b3bb 100644 --- a/mgmt/rest/tls.go +++ b/mgmt/rest/snapTLS.go @@ -31,12 +31,12 @@ import ( "time" ) -type tls struct { +type snapTLS struct { cert, key string } -func newtls(certPath, keyPath string) (*tls, error) { - t := &tls{} +func newtls(certPath, keyPath string) (*snapTLS, error) { + t := &snapTLS{} if certPath != "" && keyPath != "" { cert, err := os.Open(certPath) if err != nil { @@ -78,7 +78,7 @@ func newtls(certPath, keyPath string) (*tls, error) { return t, nil } -func generateCert(t *tls) error { +func generateCert(t *snapTLS) error { // good for 1 year notBefore := time.Now() notAfter := notBefore.Add(time.Hour * 24 * 365) diff --git a/snapd.go b/snapd.go index 912741664..5f1ee60f3 100644 --- a/snapd.go +++ b/snapd.go @@ -83,6 +83,9 @@ var ( gitversion string coreModules []coreModule + // used to save a reference to the CLi App + cliApp *cli.App + // log levels l = map[int]string{ 1: "debug", @@ -192,11 +195,11 @@ func main() { gitversion = "unknown" } - app := cli.NewApp() - app.Name = "snapd" - app.Version = gitversion - app.Usage = "A powerful telemetry framework" - app.Flags = []cli.Flag{ + cliApp = cli.NewApp() + cliApp.Name = "snapd" + cliApp.Version = gitversion + cliApp.Usage = "A powerful telemetry framework" + cliApp.Flags = []cli.Flag{ flLogLevel, flLogPath, flLogTruncate, @@ -204,13 +207,14 @@ func main() { flMaxProcs, flConfig, } - app.Flags = append(app.Flags, control.Flags...) - app.Flags = append(app.Flags, scheduler.Flags...) - app.Flags = append(app.Flags, rest.Flags...) - app.Flags = append(app.Flags, tribe.Flags...) + cliApp.Flags = append(cliApp.Flags, control.Flags...) + cliApp.Flags = append(cliApp.Flags, scheduler.Flags...) + cliApp.Flags = append(cliApp.Flags, rest.Flags...) + cliApp.Flags = append(cliApp.Flags, tribe.Flags...) + + cliApp.Action = action - app.Action = action - if app.Run(os.Args) != nil { + if cliApp.Run(os.Args) != nil { os.Exit(1) } } @@ -357,7 +361,8 @@ func action(ctx *cli.Context) error { log.Info("REST API is disabled") } - // Set interrupt handling so we can die gracefully. + // Set interrupt handling so we can either restart the app on a SIGHUP or + // die gracefully when an interrupt, kill, etc. are received startInterruptHandling(coreModules...) // Start our modules @@ -945,7 +950,7 @@ func printErrorAndExit(name string, err error) { func startInterruptHandling(modules ...coreModule) { c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, os.Kill, syscall.SIGTERM) + signal.Notify(c, os.Interrupt, os.Kill, syscall.SIGTERM, syscall.SIGHUP) //Let's block until someone tells us to quit go func() { @@ -965,13 +970,28 @@ func startInterruptHandling(modules ...coreModule) { }).Info("stopping module") m.Stop() } - log.WithFields( - log.Fields{ - "block": "main", - "_module": "snapd", - "signal": sig.String(), - }).Info("exiting on signal") - os.Exit(0) + if sig == syscall.SIGHUP { + // log the action we're taking (restarting the app) + log.WithFields( + log.Fields{ + "block": "main", + "_module": "snapd", + "signal": sig.String(), + }).Info("restarting app") + // and restart the app (with the current configuration) + err := cliApp.Run(os.Args) + if err != nil { + os.Exit(1) + } + } else { + log.WithFields( + log.Fields{ + "block": "main", + "_module": "snapd", + "signal": sig.String(), + }).Info("exiting on signal") + os.Exit(0) + } }() }