diff --git a/CHANGELOG.md b/CHANGELOG.md index fb8cb3d14a1..8a3456a3144 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,7 +47,7 @@ and this project adheres to - Improved HTTP requests handling and timeouts ([#2343]). - Our snap package now uses the `core20` image as its base ([#2306]). - New build system and various internal improvements ([#2271], [#2276], [#2297], - [#2509]). + [#2509], [#2552]). [#2231]: https://github.com/AdguardTeam/AdGuardHome/issues/2231 [#2271]: https://github.com/AdguardTeam/AdGuardHome/issues/2271 @@ -59,6 +59,7 @@ and this project adheres to [#2391]: https://github.com/AdguardTeam/AdGuardHome/issues/2391 [#2394]: https://github.com/AdguardTeam/AdGuardHome/issues/2394 [#2509]: https://github.com/AdguardTeam/AdGuardHome/issues/2509 +[#2552]: https://github.com/AdguardTeam/AdGuardHome/issues/2552 [#2589]: https://github.com/AdguardTeam/AdGuardHome/issues/2589 ### Deprecated diff --git a/internal/home/controlinstall.go b/internal/home/controlinstall.go index da223ebd4ef..0ab24f60942 100644 --- a/internal/home/controlinstall.go +++ b/internal/home/controlinstall.go @@ -13,6 +13,7 @@ import ( "runtime" "strconv" "strings" + "time" "github.com/AdguardTeam/AdGuardHome/internal/util" @@ -268,6 +269,9 @@ func copyInstallSettings(dst, src *configuration) { dst.DNS.Port = src.DNS.Port } +// shutdownTimeout is the timeout for shutting HTTP server down operation. +const shutdownTimeout = 5 * time.Second + // Apply new configuration, start DNS server, restart Web server func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) { newSettings := applyConfigReq{} @@ -320,6 +324,10 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) { config.DNS.BindHost = newSettings.DNS.IP config.DNS.Port = newSettings.DNS.Port + // TODO(e.burkov): StartMods() should be put in a separate goroutine at + // the moment we'll allow setting up TLS in the initial configuration or + // the configuration itself will use HTTPS protocol, because the + // underlying functions potentially restart the HTTPS server. err = StartMods() if err != nil { Context.firstRun = true @@ -351,16 +359,22 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) { f.Flush() } - // this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block - // until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely + // The Shutdown() method of (*http.Server) needs to be called in a + // separate goroutine, because it waits until all requests are handled + // and will be blocked by it's own caller. if restartHTTP { - go func() { - _ = web.httpServer.Shutdown(context.TODO()) - }() + ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) + + shut := func(srv *http.Server) { + defer cancel() + err := srv.Shutdown(ctx) + if err != nil { + log.Debug("error while shutting down HTTP server: %s", err) + } + } + go shut(web.httpServer) if web.httpServerBeta != nil { - go func() { - _ = web.httpServerBeta.Shutdown(context.TODO()) - }() + go shut(web.httpServerBeta) } } } diff --git a/internal/home/controlupdate.go b/internal/home/controlupdate.go index 327f2f908b1..8b3f1bf519f 100644 --- a/internal/home/controlupdate.go +++ b/internal/home/controlupdate.go @@ -1,6 +1,7 @@ package home import ( + "context" "encoding/json" "errors" "net/http" @@ -90,7 +91,7 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { } } -// Perform an update procedure to the latest available version +// handleUpdate performs an update to the latest available version procedure. func handleUpdate(w http.ResponseWriter, _ *http.Request) { if Context.updater.NewVersion() == "" { httpError(w, http.StatusBadRequest, "/update request isn't allowed now") @@ -108,7 +109,13 @@ func handleUpdate(w http.ResponseWriter, _ *http.Request) { f.Flush() } - go finishUpdate() + // The background context is used because the underlying functions wrap + // it with timeout and shut down the server, which handles current + // request. It also should be done in a separate goroutine due to the + // same reason. + go func() { + finishUpdate(context.Background()) + }() } // versionResponse is the response for /control/version.json endpoint. @@ -140,10 +147,10 @@ func (vr *versionResponse) confirmAutoUpdate() { } } -// Complete an update procedure -func finishUpdate() { +// finishUpdate completes an update procedure. +func finishUpdate(ctx context.Context) { log.Info("Stopping all tasks") - cleanup() + cleanup(ctx) cleanupAlways() exeName := "AdGuardHome" diff --git a/internal/home/home.go b/internal/home/home.go index 4ece8f85924..3c068912c4f 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -108,7 +108,7 @@ func Main() { Context.tls.Reload() default: - cleanup() + cleanup(context.Background()) cleanupAlways() os.Exit(0) } @@ -334,7 +334,7 @@ func run(args options) { select {} } -// StartMods - initialize and start DNS after installation +// StartMods initializes and starts the DNS server after installation. func StartMods() error { err := initDNSServer() if err != nil { @@ -501,11 +501,12 @@ func configureLogger(args options) { } } -func cleanup() { +// cleanup stops and resets all the modules. +func cleanup(ctx context.Context) { log.Info("Stopping AdGuard Home") if Context.web != nil { - Context.web.Close() + Context.web.Close(ctx) Context.web = nil } if Context.auth != nil { diff --git a/internal/home/home_test.go b/internal/home/home_test.go index 344657e9efa..5f3f3d81462 100644 --- a/internal/home/home_test.go +++ b/internal/home/home_test.go @@ -186,6 +186,6 @@ func TestHome(t *testing.T) { time.Sleep(1 * time.Second) } - cleanup() + cleanup(context.Background()) cleanupAlways() } diff --git a/internal/home/tls.go b/internal/home/tls.go index de6dcd186b4..1f68cc3950f 100644 --- a/internal/home/tls.go +++ b/internal/home/tls.go @@ -1,6 +1,7 @@ package home import ( + "context" "crypto" "crypto/ecdsa" "crypto/rsa" @@ -92,7 +93,7 @@ func (t *TLSMod) setCertFileTime() { t.certLastMod = fi.ModTime().UTC() } -// Start - start the module +// Start updates the configuration of TLSMod and starts it. func (t *TLSMod) Start() { if !tlsWebHandlersRegistered { tlsWebHandlersRegistered = true @@ -102,10 +103,14 @@ func (t *TLSMod) Start() { t.confLock.Lock() tlsConf := t.conf t.confLock.Unlock() - Context.web.TLSConfigChanged(tlsConf) + + // The background context is used because the TLSConfigChanged wraps + // context with timeout on its own and shuts down the server, which + // handles current request. + Context.web.TLSConfigChanged(context.Background(), tlsConf) } -// Reload - reload certificate file +// Reload updates the configuration of TLSMod and restarts it. func (t *TLSMod) Reload() { t.confLock.Lock() tlsConf := t.conf @@ -139,7 +144,10 @@ func (t *TLSMod) Reload() { t.confLock.Lock() tlsConf = t.conf t.confLock.Unlock() - Context.web.TLSConfigChanged(tlsConf) + // The background context is used because the TLSConfigChanged wraps + // context with timeout on its own and shuts down the server, which + // handles current request. + Context.web.TLSConfigChanged(context.Background(), tlsConf) } // Set certificate and private key data @@ -296,11 +304,13 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) { f.Flush() } - // this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block - // until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely + // The background context is used because the TLSConfigChanged wraps + // context with timeout on its own and shuts down the server, which + // handles current request. It is also should be done in a separate + // goroutine due to the same reason. if restartHTTPS { go func() { - Context.web.TLSConfigChanged(data) + Context.web.TLSConfigChanged(context.Background(), data) }() } } diff --git a/internal/home/web.go b/internal/home/web.go index 72a2eb62198..0048f42755a 100644 --- a/internal/home/web.go +++ b/internal/home/web.go @@ -122,8 +122,9 @@ func WebCheckPortAvailable(port int) bool { return true } -// TLSConfigChanged - called when TLS configuration has changed -func (web *Web) TLSConfigChanged(tlsConf tlsConfigSettings) { +// TLSConfigChanged updates the TLS configuration and restarts the HTTPS server +// if necessary. +func (web *Web) TLSConfigChanged(ctx context.Context, tlsConf tlsConfigSettings) { log.Debug("Web: applying new TLS configuration") web.conf.PortHTTPS = tlsConf.PortHTTPS web.forceHTTPS = (tlsConf.ForceHTTPS && tlsConf.Enabled && tlsConf.PortHTTPS != 0) @@ -143,7 +144,12 @@ func (web *Web) TLSConfigChanged(tlsConf tlsConfigSettings) { web.httpsServer.cond.L.Lock() if web.httpsServer.server != nil { - _ = web.httpsServer.server.Shutdown(context.TODO()) + ctx, cancel := context.WithTimeout(ctx, shutdownTimeout) + err = web.httpsServer.server.Shutdown(ctx) + cancel() + if err != nil { + log.Debug("error while shutting down HTTP server: %s", err) + } } web.httpsServer.enabled = enabled web.httpsServer.cert = cert @@ -198,22 +204,28 @@ func (web *Web) Start() { } } -// Close - stop HTTP server, possibly waiting for all active connections to be closed -func (web *Web) Close() { +// Close gracefully shuts down the HTTP servers. +func (web *Web) Close(ctx context.Context) { log.Info("Stopping HTTP server...") web.httpsServer.cond.L.Lock() web.httpsServer.shutdown = true web.httpsServer.cond.L.Unlock() - if web.httpsServer.server != nil { - _ = web.httpsServer.server.Shutdown(context.TODO()) - } - if web.httpServer != nil { - _ = web.httpServer.Shutdown(context.TODO()) - } - if web.httpServerBeta != nil { - _ = web.httpServerBeta.Shutdown(context.TODO()) + + shut := func(srv *http.Server) { + if srv == nil { + return + } + ctx, cancel := context.WithTimeout(ctx, shutdownTimeout) + defer cancel() + if err := srv.Shutdown(ctx); err != nil { + log.Debug("error while shutting down HTTP server: %s", err) + } } + shut(web.httpsServer.server) + shut(web.httpServer) + shut(web.httpServerBeta) + log.Info("Stopped HTTP server") } diff --git a/internal/home/whois.go b/internal/home/whois.go index 6c40ed54e98..1d849673b62 100644 --- a/internal/home/whois.go +++ b/internal/home/whois.go @@ -35,19 +35,20 @@ type Whois struct { ipAddrs cache.Cache } -// Create module context +// initWhois creates the Whois module context. func initWhois(clients *clientsContainer) *Whois { - w := Whois{} - w.timeoutMsec = 5000 - w.clients = clients - - cconf := cache.Config{} - cconf.EnableLRU = true - cconf.MaxCount = 10000 - w.ipAddrs = cache.New(cconf) + w := Whois{ + timeoutMsec: 5000, + clients: clients, + ipAddrs: cache.New(cache.Config{ + EnableLRU: true, + MaxCount: 10000, + }), + ipChan: make(chan net.IP, 255), + } - w.ipChan = make(chan net.IP, 255) go w.workerLoop() + return &w } @@ -120,12 +121,12 @@ func whoisParse(data string) map[string]string { const MaxConnReadSize = 64 * 1024 // Send request to a server and receive the response -func (w *Whois) query(target, serverAddr string) (string, error) { +func (w *Whois) query(ctx context.Context, target, serverAddr string) (string, error) { addr, _, _ := net.SplitHostPort(serverAddr) if addr == "whois.arin.net" { target = "n + " + target } - conn, err := customDialContext(context.TODO(), "tcp", serverAddr) + conn, err := customDialContext(ctx, "tcp", serverAddr) if err != nil { return "", err } @@ -153,11 +154,11 @@ func (w *Whois) query(target, serverAddr string) (string, error) { } // Query WHOIS servers (handle redirects) -func (w *Whois) queryAll(target string) (string, error) { +func (w *Whois) queryAll(ctx context.Context, target string) (string, error) { server := net.JoinHostPort(defaultServer, defaultPort) const maxRedirects = 5 for i := 0; i != maxRedirects; i++ { - resp, err := w.query(target, server) + resp, err := w.query(ctx, target, server) if err != nil { return "", err } @@ -183,9 +184,9 @@ func (w *Whois) queryAll(target string) (string, error) { } // Request WHOIS information -func (w *Whois) process(ip net.IP) [][]string { +func (w *Whois) process(ctx context.Context, ip net.IP) [][]string { data := [][]string{} - resp, err := w.queryAll(ip.String()) + resp, err := w.queryAll(ctx, ip.String()) if err != nil { log.Debug("Whois: error: %s IP:%s", err, ip) return data @@ -232,12 +233,13 @@ func (w *Whois) Begin(ip net.IP) { } } -// Get IP address from channel; get WHOIS info; associate info with a client +// workerLoop processes the IP addresses it got from the channel and associates +// the retrieving WHOIS info with a client. func (w *Whois) workerLoop() { for { ip := <-w.ipChan - info := w.process(ip) + info := w.process(context.Background(), ip) if len(info) == 0 { continue } diff --git a/internal/home/whois_test.go b/internal/home/whois_test.go index f8109e9268d..a160cdda489 100644 --- a/internal/home/whois_test.go +++ b/internal/home/whois_test.go @@ -1,6 +1,7 @@ package home import ( + "context" "testing" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" @@ -19,7 +20,7 @@ func TestWhois(t *testing.T) { assert.Nil(t, prepareTestDNSServer()) w := Whois{timeoutMsec: 5000} - resp, err := w.queryAll("8.8.8.8") + resp, err := w.queryAll(context.Background(), "8.8.8.8") assert.Nil(t, err) m := whoisParse(resp) assert.Equal(t, "Google LLC", m["orgname"])